@@ -22,6 +22,7 @@ use core::ops::{FnMut, Deref};
22
22
use key:: { SecretKey , PublicKey } ;
23
23
use ffi:: { self , CPtr } ;
24
24
use secp256k1_sys:: types:: { c_int, c_uchar, c_void} ;
25
+ use Error ;
25
26
26
27
/// A tag used for recovering the public key from a compact signature
27
28
#[ derive( Copy , Clone ) ]
@@ -63,6 +64,7 @@ impl SharedSecret {
63
64
64
65
/// Set the length of the object.
65
66
pub ( crate ) fn set_len ( & mut self , len : usize ) {
67
+ debug_assert ! ( len <= self . data. len( ) ) ;
66
68
self . len = len;
67
69
}
68
70
}
@@ -87,19 +89,29 @@ impl Deref for SharedSecret {
87
89
}
88
90
89
91
92
+ #[ cfg( feature = "std" ) ]
90
93
unsafe extern "C" fn hash_callback < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
91
94
where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
92
- let callback: & mut F = & mut * ( data as * mut F ) ;
93
95
94
- let mut x_arr = [ 0 ; 32 ] ;
95
- let mut y_arr = [ 0 ; 32 ] ;
96
- ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
97
- ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
96
+ use std:: panic:: catch_unwind;
97
+ let res = catch_unwind ( || {
98
+ let callback: & mut F = & mut * ( data as * mut F ) ;
98
99
99
- let secret = callback ( x_arr, y_arr) ;
100
- ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
100
+ let mut x_arr = [ 0 ; 32 ] ;
101
+ let mut y_arr = [ 0 ; 32 ] ;
102
+ ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
103
+ ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
101
104
102
- secret. len ( ) as c_int
105
+ let secret = callback ( x_arr, y_arr) ;
106
+ ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
107
+
108
+ secret. len ( ) as c_int
109
+ } ) ;
110
+ if let Ok ( len) = res {
111
+ len
112
+ } else {
113
+ -1
114
+ }
103
115
}
104
116
105
117
@@ -140,7 +152,8 @@ impl SharedSecret {
140
152
/// });
141
153
///
142
154
/// ```
143
- pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> SharedSecret
155
+ #[ cfg( feature = "std" ) ]
156
+ pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> Result < SharedSecret , Error >
144
157
where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret
145
158
{
146
159
let mut ss = SharedSecret :: empty ( ) ;
@@ -156,9 +169,12 @@ impl SharedSecret {
156
169
& mut hash_function as * mut F as * mut c_void ,
157
170
)
158
171
} ;
172
+ if res == -1 {
173
+ return Err ( Error :: CallbackPanicked ) ;
174
+ }
159
175
debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
160
176
ss. set_len ( res as usize ) ;
161
- ss
177
+ Ok ( ss )
162
178
}
163
179
}
164
180
@@ -167,6 +183,7 @@ mod tests {
167
183
use rand:: thread_rng;
168
184
use super :: SharedSecret ;
169
185
use super :: super :: Secp256k1 ;
186
+ use Error ;
170
187
171
188
#[ test]
172
189
fn ecdh ( ) {
@@ -187,9 +204,9 @@ mod tests {
187
204
let ( sk1, pk1) = s. generate_keypair ( & mut thread_rng ( ) ) ;
188
205
let ( sk2, pk2) = s. generate_keypair ( & mut thread_rng ( ) ) ;
189
206
190
- let sec1 = SharedSecret :: new_with_hash ( & pk1, & sk2, |x, _| x. into ( ) ) ;
191
- let sec2 = SharedSecret :: new_with_hash ( & pk2, & sk1, |x, _| x. into ( ) ) ;
192
- let sec_odd = SharedSecret :: new_with_hash ( & pk1, & sk1, |x, _| x. into ( ) ) ;
207
+ let sec1 = SharedSecret :: new_with_hash ( & pk1, & sk2, |x, _| x. into ( ) ) . unwrap ( ) ;
208
+ let sec2 = SharedSecret :: new_with_hash ( & pk2, & sk1, |x, _| x. into ( ) ) . unwrap ( ) ;
209
+ let sec_odd = SharedSecret :: new_with_hash ( & pk1, & sk1, |x, _| x. into ( ) ) . unwrap ( ) ;
193
210
assert_eq ! ( sec1, sec2) ;
194
211
assert_ne ! ( sec_odd, sec2) ;
195
212
}
@@ -205,11 +222,23 @@ mod tests {
205
222
x_out = x;
206
223
y_out = y;
207
224
expect_result. into ( )
208
- } ) ;
225
+ } ) . unwrap ( ) ;
209
226
assert_eq ! ( & expect_result[ ..] , & result[ ..] ) ;
210
227
assert_ne ! ( x_out, [ 0u8 ; 32 ] ) ;
211
228
assert_ne ! ( y_out, [ 0u8 ; 32 ] ) ;
212
229
}
230
+
231
+ #[ test]
232
+ fn ecdh_with_hash_callback_panic ( ) {
233
+ let s = Secp256k1 :: signing_only ( ) ;
234
+ let ( sk1, pk1) = s. generate_keypair ( & mut thread_rng ( ) ) ;
235
+ let mut res = [ 0u8 ; 48 ] ;
236
+ let result = SharedSecret :: new_with_hash ( & pk1, & sk1, | x, _ | {
237
+ res. copy_from_slice ( & x) ; // res.len() != x.len(). this will panic.
238
+ res. into ( )
239
+ } ) ;
240
+ assert_eq ! ( result, Err ( Error :: CallbackPanicked ) ) ;
241
+ }
213
242
}
214
243
215
244
#[ cfg( all( test, feature = "unstable" ) ) ]
0 commit comments