Skip to content

Commit 5619f2a

Browse files
committed
Add an unsafe variant of new_with_has called new_with_hash_no_panic
1 parent 124c1f3 commit 5619f2a

File tree

1 file changed

+84
-34
lines changed

1 file changed

+84
-34
lines changed

src/ecdh.rs

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -89,31 +89,38 @@ impl Deref for SharedSecret {
8989
}
9090

9191

92-
#[cfg(feature = "std")]
93-
unsafe extern "C" fn hash_callback<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
92+
unsafe fn callback_logic<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
9493
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
94+
let callback: &mut F = &mut *(data as *mut F);
95+
96+
let mut x_arr = [0; 32];
97+
let mut y_arr = [0; 32];
98+
ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32);
99+
ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32);
95100

96-
use std::panic::catch_unwind;
97-
let res = catch_unwind(|| {
98-
let callback: &mut F = &mut *(data as *mut F);
101+
let secret = callback(x_arr, y_arr);
102+
ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len());
99103

100-
let mut x_arr = [0; 32];
101-
let mut y_arr = [0; 32];
102-
ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32);
103-
ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32);
104+
secret.len() as c_int
105+
}
104106

105-
let secret = callback(x_arr, y_arr);
106-
ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len());
107+
#[cfg(feature = "std")]
108+
unsafe extern "C" fn hash_callback_catch_unwind<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
109+
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
107110

108-
secret.len() as c_int
109-
});
111+
let res = ::std::panic::catch_unwind(||callback_logic::<F>(output, x, y, data));
110112
if let Ok(len) = res {
111113
len
112114
} else {
113115
-1
114116
}
115117
}
116118

119+
unsafe extern "C" fn hash_callback_unsafe<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
120+
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
121+
callback_logic::<F>(output, x, y, data)
122+
}
123+
117124

118125
impl SharedSecret {
119126
/// Creates a new shared secret from a pubkey and secret key
@@ -135,6 +142,29 @@ impl SharedSecret {
135142
ss
136143
}
137144

145+
fn new_with_callback_internal<F>(point: &PublicKey, scalar: &SecretKey, mut closure: F, callback: ffi::EcdhHashFn) -> Result<SharedSecret, Error>
146+
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
147+
let mut ss = SharedSecret::empty();
148+
149+
let res = unsafe {
150+
ffi::secp256k1_ecdh(
151+
ffi::secp256k1_context_no_precomp,
152+
ss.get_data_mut_ptr(),
153+
point.as_ptr(),
154+
scalar.as_ptr(),
155+
callback,
156+
&mut closure as *mut F as *mut c_void,
157+
)
158+
};
159+
if res == -1 {
160+
return Err(Error::CallbackPanicked);
161+
}
162+
debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users.
163+
ss.set_len(res as usize);
164+
Ok(ss)
165+
166+
}
167+
138168
/// Creates a new shared secret from a pubkey and secret key with applied custom hash function
139169
/// # Examples
140170
/// ```
@@ -153,28 +183,42 @@ impl SharedSecret {
153183
///
154184
/// ```
155185
#[cfg(feature = "std")]
156-
pub fn new_with_hash<F>(point: &PublicKey, scalar: &SecretKey, mut hash_function: F) -> Result<SharedSecret, Error>
157-
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret
158-
{
159-
let mut ss = SharedSecret::empty();
160-
let hashfp: ffi::EcdhHashFn = hash_callback::<F>;
186+
pub fn new_with_hash<F>(point: &PublicKey, scalar: &SecretKey, hash_function: F) -> Result<SharedSecret, Error>
187+
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
188+
Self::new_with_callback_internal(point, scalar, hash_function, hash_callback_catch_unwind::<F>)
189+
}
161190

162-
let res = unsafe {
163-
ffi::secp256k1_ecdh(
164-
ffi::secp256k1_context_no_precomp,
165-
ss.get_data_mut_ptr(),
166-
point.as_ptr(),
167-
scalar.as_ptr(),
168-
hashfp,
169-
&mut hash_function as *mut F as *mut c_void,
170-
)
171-
};
172-
if res == -1 {
173-
return Err(Error::CallbackPanicked);
174-
}
175-
debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users.
176-
ss.set_len(res as usize);
177-
Ok(ss)
191+
/// Creates a new shared secret from a pubkey and secret key with applied custom hash function
192+
/// Note that this function is the same as [`new_with_hash`]
193+
///
194+
/// # Safety
195+
/// The function doesn't wrap the callback with [`catch_unwind`]
196+
/// so if the callback panics it will panic through an FFI boundray which is [`Undefined Behavior`]
197+
/// If possible you should use [`new_with_hash`] which does wrap the callback with [`catch_unwind`] so is safe to use.
198+
///
199+
/// [`catch_unwind`]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html
200+
/// [`Undefined Behavior`]: https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-panics
201+
/// [`new_with_hash`]: #method.new_with_hash
202+
/// # Examples
203+
/// ```
204+
/// # use secp256k1::ecdh::SharedSecret;
205+
/// # use secp256k1::{Secp256k1, PublicKey, SecretKey};
206+
/// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]}
207+
/// # let secp = Secp256k1::signing_only();
208+
/// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap();
209+
/// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap();
210+
/// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2);
211+
//
212+
/// let secret = unsafe { SharedSecret::new_with_hash_no_panic(&public_key, &secret_key, |x,y| {
213+
/// let hash: [u8; 32] = sha2(&x,&y);
214+
/// hash.into()
215+
/// })};
216+
///
217+
///
218+
/// ```
219+
pub unsafe fn new_with_hash_no_panic<F>(point: &PublicKey, scalar: &SecretKey, hash_function: F) -> Result<SharedSecret, Error>
220+
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
221+
Self::new_with_callback_internal(point, scalar, hash_function, hash_callback_unsafe::<F>)
178222
}
179223
}
180224

@@ -223,7 +267,13 @@ mod tests {
223267
y_out = y;
224268
expect_result.into()
225269
}).unwrap();
270+
let result_unsafe = unsafe {SharedSecret::new_with_hash_no_panic(&pk1, &sk1, | x, y | {
271+
x_out = x;
272+
y_out = y;
273+
expect_result.into()
274+
}).unwrap()};
226275
assert_eq!(&expect_result[..], &result[..]);
276+
assert_eq!(result, result_unsafe);
227277
assert_ne!(x_out, [0u8; 32]);
228278
assert_ne!(y_out, [0u8; 32]);
229279
}

0 commit comments

Comments
 (0)