Skip to content

Commit ec9cbe7

Browse files
committed
Add more compat_fn machinery
1 parent 54a13f1 commit ec9cbe7

File tree

4 files changed

+170
-25
lines changed

4 files changed

+170
-25
lines changed

library/std/src/sys/windows/c.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#![unstable(issue = "none", feature = "windows_c")]
66
#![allow(clippy::style)]
77

8-
use crate::ffi::CStr;
98
use crate::mem;
109
pub use crate::os::raw::c_int;
1110
use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort, c_void};
@@ -324,7 +323,7 @@ pub unsafe fn NtWriteFile(
324323
// Functions that aren't available on every version of Windows that we support,
325324
// but we still use them and just provide some form of a fallback implementation.
326325
compat_fn_with_fallback! {
327-
pub static KERNEL32: &CStr = c"kernel32";
326+
pub static KERNEL32: &CStr = c"kernel32" => { load: false, unicows: false };
328327

329328
// >= Win10 1607
330329
// https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-setthreaddescription
@@ -357,7 +356,7 @@ compat_fn_optional! {
357356
}
358357

359358
compat_fn_with_fallback! {
360-
pub static NTDLL: &CStr = c"ntdll";
359+
pub static NTDLL: &CStr = c"ntdll" => { load: true, unicows: false };
361360

362361
pub fn NtCreateKeyedEvent(
363362
KeyedEventHandle: LPHANDLE,

library/std/src/sys/windows/c/windows_sys.lst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,3 +2594,4 @@ Windows.Win32.System.WindowsProgramming.PROGRESS_CONTINUE
25942594
Windows.Win32.UI.Shell.GetUserProfileDirectoryW
25952595
// tidy-alphabetical-end
25962596

2597+
Windows.Win32.System.LibraryLoader.LoadLibraryA

library/std/src/sys/windows/c/windows_sys.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ extern "system" {
394394
) -> BOOL;
395395
}
396396
#[link(name = "kernel32")]
397+
extern "system" {
398+
pub fn LoadLibraryA(lplibfilename: PCSTR) -> HMODULE;
399+
}
400+
#[link(name = "kernel32")]
397401
extern "system" {
398402
pub fn MoveFileExW(
399403
lpexistingfilename: PCWSTR,

library/std/src/sys/windows/compat.rs

Lines changed: 163 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ impl Module {
115115
NonNull::new(module).map(Self)
116116
}
117117

118+
#[allow(dead_code)]
119+
pub unsafe fn load(name: &CStr) -> Option<Self> {
120+
// SAFETY: A CStr is always null terminated.
121+
let module = c::LoadLibraryA(name.as_ptr().cast::<u8>());
122+
NonNull::new(module).map(Self)
123+
}
124+
118125
// Try to get the address of a function.
119126
pub fn proc_address(self, name: &CStr) -> Option<NonNull<c_void>> {
120127
unsafe {
@@ -128,24 +135,28 @@ impl Module {
128135
}
129136
}
130137

138+
pub static UNICOWS: &CStr = c"unicows";
139+
131140
/// Load a function or use a fallback implementation if that fails.
132141
macro_rules! compat_fn_with_fallback {
133-
(pub static $module:ident: &CStr = $name:expr; $(
134-
$(#[$meta:meta])*
135-
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty $fallback_body:block
136-
)*) => (
137-
pub static $module: &CStr = $name;
142+
{
143+
pub static $module:ident: &CStr = $name:expr => { load: $load:expr, unicows: $unicows:expr };
144+
$(
145+
$(#[$meta:meta])*
146+
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),* $(,)?) $(-> $rettype:ty)? $fallback_body:block
147+
)+
148+
} => {
138149
$(
139150
$(#[$meta])*
140151
pub mod $symbol {
141152
#[allow(unused_imports)]
142153
use super::*;
143154
use crate::mem;
144-
use crate::ffi::CStr;
155+
use crate::ffi::{CStr, c_void};
145156
use crate::sync::atomic::{AtomicPtr, Ordering};
146-
use crate::sys::compat::Module;
157+
use crate::sys::compat::{Module, UNICOWS};
147158

148-
type F = unsafe extern "system" fn($($argtype),*) -> $rettype;
159+
type F = unsafe extern "system" fn($($argtype),*) $(-> $rettype)?;
149160

150161
/// `PTR` contains a function pointer to one of three functions.
151162
/// It starts with the `load` function.
@@ -154,15 +165,30 @@ macro_rules! compat_fn_with_fallback {
154165
/// If it fails, then `PTR` is set to `fallback`.
155166
static PTR: AtomicPtr<c_void> = AtomicPtr::new(load as *mut _);
156167

157-
unsafe extern "system" fn load($($argname: $argtype),*) -> $rettype {
158-
let func = load_from_module(Module::new($module));
168+
unsafe extern "system" fn load($($argname: $argtype),*) $(-> $rettype)? {
169+
let func = load_from_module();
159170
func($($argname),*)
160171
}
161172

162-
fn load_from_module(module: Option<Module>) -> F {
173+
fn load_from_module() -> F {
163174
unsafe {
164175
static SYMBOL_NAME: &CStr = ansi_str!(sym $symbol);
165-
if let Some(f) = module.and_then(|m| m.proc_address(SYMBOL_NAME)) {
176+
177+
let in_unicows = if $unicows {
178+
Module::new(UNICOWS).and_then(|m| m.proc_address(SYMBOL_NAME))
179+
} else {
180+
None
181+
};
182+
183+
let f = in_unicows.or_else(|| {
184+
if $load {
185+
Module::new($name)
186+
} else {
187+
Module::load($name)
188+
}.and_then(|m| m.proc_address(SYMBOL_NAME))
189+
});
190+
191+
if let Some(f) = f {
166192
PTR.store(f.as_ptr(), Ordering::Relaxed);
167193
mem::transmute(f)
168194
} else {
@@ -172,20 +198,31 @@ macro_rules! compat_fn_with_fallback {
172198
}
173199
}
174200

201+
#[allow(dead_code)]
202+
pub fn available() -> bool {
203+
let mut ptr = PTR.load(Ordering::Relaxed);
204+
if ptr == load as *mut _ {
205+
ptr = load_from_module() as *mut _;
206+
}
207+
208+
ptr != fallback as *mut _
209+
}
210+
175211
#[allow(unused_variables)]
176-
unsafe extern "system" fn fallback($($argname: $argtype),*) -> $rettype {
212+
unsafe extern "system" fn fallback($($argname: $argtype),*) $(-> $rettype)? {
177213
$fallback_body
178214
}
179215

180216
#[inline(always)]
181-
pub unsafe fn call($($argname: $argtype),*) -> $rettype {
217+
pub unsafe fn call($($argname: $argtype),*) $(-> $rettype)? {
182218
let func: F = mem::transmute(PTR.load(Ordering::Relaxed));
183219
func($($argname),*)
184220
}
185221
}
186222
$(#[$meta])*
187223
$vis use $symbol::call as $symbol;
188-
)*)
224+
)*
225+
}
189226
}
190227

191228
/// Optionally loaded functions.
@@ -195,7 +232,7 @@ macro_rules! compat_fn_optional {
195232
($load_functions:expr;
196233
$(
197234
$(#[$meta:meta])*
198-
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),*) $(-> $rettype:ty)?;
235+
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),* $(,)?) $(-> $rettype:ty)?;
199236
)+) => (
200237
$(
201238
pub mod $symbol {
@@ -211,32 +248,136 @@ macro_rules! compat_fn_optional {
211248
type F = unsafe extern "system" fn($($argtype),*) $(-> $rettype)?;
212249

213250
#[inline(always)]
251+
#[allow(dead_code)]
214252
pub fn option() -> Option<F> {
215253
// Miri does not understand the way we do preloading
216254
// therefore load the function here instead.
217255
#[cfg(miri)] $load_functions;
218256
NonNull::new(PTR.load(Ordering::Relaxed)).map(|f| unsafe { mem::transmute(f) })
219257
}
258+
259+
#[inline(always)]
260+
#[allow(dead_code)]
261+
pub unsafe fn call($($argname: $argtype),*) $(-> $rettype)? {
262+
(mem::transmute::<_, F>(PTR.load(Ordering::Relaxed)))($($argname),*)
263+
}
220264
}
265+
266+
#[allow(unused_imports)]
267+
$(#[$meta])*
268+
$vis use $symbol::call as $symbol;
221269
)+
222270
)
223271
}
224272

273+
macro_rules! compat_fn_lazy {
274+
{
275+
pub static $module:ident: &CStr = $name:expr => { load: $load:expr, unicows: $unicows:expr };
276+
$(
277+
$(#[$meta:meta])*
278+
$vis:vis fn $symbol:ident($($argname:ident: $argtype:ty),* $(,)?) $(-> $rettype:ty)?;
279+
)+
280+
} => {
281+
$(
282+
$(#[$meta])*
283+
pub mod $symbol {
284+
#[allow(unused_imports)]
285+
use super::*;
286+
use crate::mem;
287+
use crate::ffi::{CStr, c_void};
288+
use crate::sync::atomic::{AtomicPtr, Ordering};
289+
use crate::sys::compat::{Module, UNICOWS};
290+
291+
type F = unsafe extern "system" fn($($argtype),*) $(-> $rettype)?;
292+
293+
/// `PTR` contains a function pointer to one of three functions.
294+
/// It starts with the `load` function.
295+
/// When that is called it attempts to load the requested symbol.
296+
/// If it succeeds, `PTR` is set to the address of that symbol.
297+
/// If it fails, then `PTR` is set to `fallback`.
298+
static PTR: AtomicPtr<c_void> = AtomicPtr::new(load as *mut _);
299+
300+
unsafe extern "system" fn load($($argname: $argtype),*) $(-> $rettype)? {
301+
let func = load_from_module();
302+
(func.unwrap())($($argname),*)
303+
}
304+
305+
fn load_from_module() -> Option<F> {
306+
unsafe {
307+
static SYMBOL_NAME: &CStr = ansi_str!(sym $symbol);
308+
309+
let in_unicows = if $unicows {
310+
Module::new(UNICOWS).and_then(|m| m.proc_address(SYMBOL_NAME))
311+
} else {
312+
None
313+
};
314+
315+
let f = in_unicows.or_else(|| {
316+
if $load {
317+
Module::new($name)
318+
} else {
319+
Module::load($name)
320+
}.and_then(|m| m.proc_address(SYMBOL_NAME))
321+
});
322+
323+
if let Some(f) = f {
324+
PTR.store(f.as_ptr(), Ordering::Relaxed);
325+
Some(mem::transmute(f))
326+
} else {
327+
PTR.store(crate::ptr::null_mut(), Ordering::Relaxed);
328+
None
329+
}
330+
}
331+
}
332+
333+
#[allow(dead_code)]
334+
pub fn option() -> Option<F> {
335+
unsafe {
336+
let ptr = PTR.load(Ordering::Relaxed);
337+
if ptr == load as *mut _ {
338+
load_from_module()
339+
} else {
340+
Some(mem::transmute(ptr))
341+
}
342+
}
343+
}
344+
345+
#[inline(always)]
346+
pub unsafe fn call($($argname: $argtype),*) $(-> $rettype)? {
347+
let func: F = mem::transmute(PTR.load(Ordering::Relaxed));
348+
func($($argname),*)
349+
}
350+
}
351+
$(#[$meta])*
352+
$vis use $symbol::call as $symbol;
353+
)*
354+
}
355+
}
356+
357+
macro_rules! static_load {
358+
(
359+
$library:expr,
360+
[$($symbol:ident),* $(,)?]
361+
) => {
362+
$(
363+
let $symbol = $library.proc_address(ansi_str!(sym $symbol))?;
364+
)*
365+
$(
366+
c::$symbol::PTR.store($symbol.as_ptr(), Ordering::Relaxed);
367+
)*
368+
}
369+
}
370+
225371
/// Load all needed functions from "api-ms-win-core-synch-l1-2-0".
226372
pub(super) fn load_synch_functions() {
227373
fn try_load() -> Option<()> {
228374
const MODULE_NAME: &CStr = c"api-ms-win-core-synch-l1-2-0";
229-
const WAIT_ON_ADDRESS: &CStr = c"WaitOnAddress";
230-
const WAKE_BY_ADDRESS_SINGLE: &CStr = c"WakeByAddressSingle";
231375

232376
// Try loading the library and all the required functions.
233377
// If any step fails, then they all fail.
234378
let library = unsafe { Module::new(MODULE_NAME) }?;
235-
let wait_on_address = library.proc_address(WAIT_ON_ADDRESS)?;
236-
let wake_by_address_single = library.proc_address(WAKE_BY_ADDRESS_SINGLE)?;
379+
static_load!(library, [WaitOnAddress, WakeByAddressSingle]);
237380

238-
c::WaitOnAddress::PTR.store(wait_on_address.as_ptr(), Ordering::Relaxed);
239-
c::WakeByAddressSingle::PTR.store(wake_by_address_single.as_ptr(), Ordering::Relaxed);
240381
Some(())
241382
}
242383

0 commit comments

Comments
 (0)