Skip to content

Commit e94c531

Browse files
committed
fixup! Optimize compat_fn!, add compat_fn_lazy!, add unicows support
1 parent 5e0572e commit e94c531

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

library/std/src/sys/windows/compat.rs

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,25 @@ impl Module {
112112
///
113113
/// This should only be use for modules that exist for the lifetime of std
114114
/// (e.g. kernel32 and ntdll).
115-
pub unsafe fn new(name: &CStr) -> Option<Self> {
115+
pub unsafe fn get(name: &CStr) -> Option<Self> {
116116
// SAFETY: A CStr is always null terminated.
117117
let module = c::GetModuleHandleA(name.as_ptr().cast::<u8>());
118118
NonNull::new(module).map(Self)
119119
}
120120

121+
/// Try to get a handle to a loaded module. If the module was not loaded, then load
122+
/// the module.
123+
///
124+
/// # SAFETY
125+
///
126+
/// This should only be used for modules that exist for the lifetime of std
127+
/// that are not pre-loaded by Windows (e.g. advapi32 and bcrypt)
128+
pub unsafe fn load(name: &CStr) -> Option<Self> {
129+
// SAFETY: A CStr is always null terminated.
130+
let module = c::LoadLibraryA(name.as_ptr().cast::<u8>());
131+
NonNull::new(module).map(Self)
132+
}
133+
121134
// Try to get the address of a function.
122135
pub fn proc_address(self, name: &CStr) -> Option<NonNull<c_void>> {
123136
unsafe {
@@ -135,7 +148,8 @@ pub(crate) const UNICOWS_MODULE_NAME: &CStr = ansi_str!("unicows");
135148

136149
/// Load a function or use a fallback implementation if that fails.
137150
macro_rules! compat_fn_with_fallback {
138-
(pub static $module:ident: &CStr = $name:expr; $(
151+
(pub static $module:ident: &CStr = $name:expr;
152+
const LOAD: bool = $load:expr; $(
139153
$(#[$meta:meta])*
140154
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty $fallback_body:block
141155
)*) => (
@@ -160,14 +174,22 @@ macro_rules! compat_fn_with_fallback {
160174
static PTR: AtomicPtr<c_void> = AtomicPtr::new(load as *mut _);
161175

162176
unsafe extern "system" fn load($($argname: $argtype),*) -> $rettype {
163-
let func = load_from_module(Module::new($module));
177+
let func = load_from_module(load_module($module));
164178
func($($argname),*)
165179
}
166180

181+
fn load_module() -> Option<Module> {
182+
if $load {
183+
Module::load($module)
184+
} else {
185+
Module::get($module)
186+
}
187+
}
188+
167189
fn load_from_module(module: Option<Module>) -> F {
168190
unsafe {
169191
static SYMBOL_NAME: &CStr = ansi_str!(sym $symbol);
170-
if let Some(f) = Module::new($crate::sys::compat::UNICOWS_MODULE_NAME)
192+
if let Some(f) = Module::get($crate::sys::compat::UNICOWS_MODULE_NAME)
171193
.and_then(|m| m.proc_address(SYMBOL_NAME))
172194
.or_else(|| module.and_then(|m| m.proc_address(SYMBOL_NAME)))
173195
{
@@ -196,7 +218,7 @@ macro_rules! compat_fn_with_fallback {
196218
}
197219

198220
// Otherwise, the function pointer should be resolved
199-
let _ = load_from_module(Module::new($module));
221+
let _ = load_from_module(load_module($module));
200222

201223
// After resolution, the function pointer will only point to `fallback` if
202224
// the target function was not found
@@ -212,6 +234,19 @@ macro_rules! compat_fn_with_fallback {
212234
$(#[$meta])*
213235
$vis use $symbol::call as $symbol;
214236
)*)
237+
(pub static $module:ident: &CStr = $name:expr; $(
238+
$(#[$meta:meta])*
239+
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty $fallback_body:block
240+
)*) => (
241+
compat_fn_with_fallback! {
242+
pub static $module: &CStr = $name;
243+
const LOAD: bool = false;
244+
$(
245+
$(#[$meta])*
246+
$vis fn $symbol($($argname: $argtype),*) -> $rettype $fallback_body
247+
)*
248+
}
249+
)
215250
}
216251

217252
/// Optionally loaded functions.
@@ -257,7 +292,7 @@ pub(super) fn load_synch_functions() {
257292

258293
// Try loading the library and all the required functions.
259294
// If any step fails, then they all fail.
260-
let library = unsafe { Module::new(MODULE_NAME) }?;
295+
let library = unsafe { Module::get(MODULE_NAME) }?;
261296
let wait_on_address = library.proc_address(WAIT_ON_ADDRESS)?;
262297
let wake_by_address_single = library.proc_address(WAKE_BY_ADDRESS_SINGLE)?;
263298

0 commit comments

Comments
 (0)