@@ -89,31 +89,38 @@ impl Deref for SharedSecret {
89
89
}
90
90
91
91
92
- #[ cfg( feature = "std" ) ]
93
- unsafe extern "C" fn hash_callback < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
92
+ unsafe fn callback_logic < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
94
93
where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
94
+ let callback: & mut F = & mut * ( data as * mut F ) ;
95
+
96
+ let mut x_arr = [ 0 ; 32 ] ;
97
+ let mut y_arr = [ 0 ; 32 ] ;
98
+ ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
99
+ ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
95
100
96
- use std:: panic:: catch_unwind;
97
- let res = catch_unwind ( || {
98
- let callback: & mut F = & mut * ( data as * mut F ) ;
101
+ let secret = callback ( x_arr, y_arr) ;
102
+ ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
99
103
100
- let mut x_arr = [ 0 ; 32 ] ;
101
- let mut y_arr = [ 0 ; 32 ] ;
102
- ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
103
- ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
104
+ secret. len ( ) as c_int
105
+ }
104
106
105
- let secret = callback ( x_arr, y_arr) ;
106
- ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
107
+ #[ cfg( feature = "std" ) ]
108
+ unsafe extern "C" fn hash_callback_catch_unwind < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
109
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
107
110
108
- secret. len ( ) as c_int
109
- } ) ;
111
+ let res = :: std:: panic:: catch_unwind ( ||callback_logic :: < F > ( output, x, y, data) ) ;
110
112
if let Ok ( len) = res {
111
113
len
112
114
} else {
113
115
-1
114
116
}
115
117
}
116
118
119
+ unsafe extern "C" fn hash_callback_unsafe < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
120
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
121
+ callback_logic :: < F > ( output, x, y, data)
122
+ }
123
+
117
124
118
125
impl SharedSecret {
119
126
/// Creates a new shared secret from a pubkey and secret key
@@ -135,6 +142,29 @@ impl SharedSecret {
135
142
ss
136
143
}
137
144
145
+ fn new_with_callback_internal < F > ( point : & PublicKey , scalar : & SecretKey , mut closure : F , callback : ffi:: EcdhHashFn ) -> Result < SharedSecret , Error >
146
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
147
+ let mut ss = SharedSecret :: empty ( ) ;
148
+
149
+ let res = unsafe {
150
+ ffi:: secp256k1_ecdh (
151
+ ffi:: secp256k1_context_no_precomp,
152
+ ss. get_data_mut_ptr ( ) ,
153
+ point. as_ptr ( ) ,
154
+ scalar. as_ptr ( ) ,
155
+ callback,
156
+ & mut closure as * mut F as * mut c_void ,
157
+ )
158
+ } ;
159
+ if res == -1 {
160
+ return Err ( Error :: CallbackPanicked ) ;
161
+ }
162
+ debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
163
+ ss. set_len ( res as usize ) ;
164
+ Ok ( ss)
165
+
166
+ }
167
+
138
168
/// Creates a new shared secret from a pubkey and secret key with applied custom hash function
139
169
/// # Examples
140
170
/// ```
@@ -153,28 +183,42 @@ impl SharedSecret {
153
183
///
154
184
/// ```
155
185
#[ cfg( feature = "std" ) ]
156
- pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> Result < SharedSecret , Error >
157
- where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret
158
- {
159
- let mut ss = SharedSecret :: empty ( ) ;
160
- let hashfp: ffi:: EcdhHashFn = hash_callback :: < F > ;
186
+ pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , hash_function : F ) -> Result < SharedSecret , Error >
187
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
188
+ Self :: new_with_callback_internal ( point, scalar, hash_function, hash_callback_catch_unwind :: < F > )
189
+ }
161
190
162
- let res = unsafe {
163
- ffi:: secp256k1_ecdh (
164
- ffi:: secp256k1_context_no_precomp,
165
- ss. get_data_mut_ptr ( ) ,
166
- point. as_ptr ( ) ,
167
- scalar. as_ptr ( ) ,
168
- hashfp,
169
- & mut hash_function as * mut F as * mut c_void ,
170
- )
171
- } ;
172
- if res == -1 {
173
- return Err ( Error :: CallbackPanicked ) ;
174
- }
175
- debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
176
- ss. set_len ( res as usize ) ;
177
- Ok ( ss)
191
+ /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
192
+ /// Note that this function is the same as [`new_with_hash`]
193
+ ///
194
+ /// # Safety
195
+ /// The function doesn't wrap the callback with [`catch_unwind`]
196
+ /// so if the callback panics it will panic through an FFI boundray which is [`Undefined Behavior`]
197
+ /// If possible you should use [`new_with_hash`] which does wrap the callback with [`catch_unwind`] so is safe to use.
198
+ ///
199
+ /// [`catch_unwind`]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html
200
+ /// [`Undefined Behavior`]: https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-panics
201
+ /// [`new_with_hash`]: #method.new_with_hash
202
+ /// # Examples
203
+ /// ```
204
+ /// # use secp256k1::ecdh::SharedSecret;
205
+ /// # use secp256k1::{Secp256k1, PublicKey, SecretKey};
206
+ /// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]}
207
+ /// # let secp = Secp256k1::signing_only();
208
+ /// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap();
209
+ /// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap();
210
+ /// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2);
211
+ //
212
+ /// let secret = unsafe { SharedSecret::new_with_hash_no_panic(&public_key, &secret_key, |x,y| {
213
+ /// let hash: [u8; 32] = sha2(&x,&y);
214
+ /// hash.into()
215
+ /// })};
216
+ ///
217
+ ///
218
+ /// ```
219
+ pub unsafe fn new_with_hash_no_panic < F > ( point : & PublicKey , scalar : & SecretKey , hash_function : F ) -> Result < SharedSecret , Error >
220
+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
221
+ Self :: new_with_callback_internal ( point, scalar, hash_function, hash_callback_unsafe :: < F > )
178
222
}
179
223
}
180
224
@@ -223,7 +267,13 @@ mod tests {
223
267
y_out = y;
224
268
expect_result. into ( )
225
269
} ) . unwrap ( ) ;
270
+ let result_unsafe = unsafe { SharedSecret :: new_with_hash_no_panic ( & pk1, & sk1, | x, y | {
271
+ x_out = x;
272
+ y_out = y;
273
+ expect_result. into ( )
274
+ } ) . unwrap ( ) } ;
226
275
assert_eq ! ( & expect_result[ ..] , & result[ ..] ) ;
276
+ assert_eq ! ( result, result_unsafe) ;
227
277
assert_ne ! ( x_out, [ 0u8 ; 32 ] ) ;
228
278
assert_ne ! ( y_out, [ 0u8 ; 32 ] ) ;
229
279
}
0 commit comments