@@ -112,12 +112,25 @@ impl Module {
112
112
///
113
113
/// This should only be use for modules that exist for the lifetime of std
114
114
/// (e.g. kernel32 and ntdll).
115
- pub unsafe fn new ( name : & CStr ) -> Option < Self > {
115
+ pub unsafe fn get ( name : & CStr ) -> Option < Self > {
116
116
// SAFETY: A CStr is always null terminated.
117
117
let module = c:: GetModuleHandleA ( name. as_ptr ( ) . cast :: < u8 > ( ) ) ;
118
118
NonNull :: new ( module) . map ( Self )
119
119
}
120
120
121
+ /// Try to get a handle to a loaded module. If the module was not loaded, then load
122
+ /// the module.
123
+ ///
124
+ /// # SAFETY
125
+ ///
126
+ /// This should only be used for modules that exist for the lifetime of std
127
+ /// that are not pre-loaded by Windows (e.g. advapi32 and bcrypt)
128
+ pub unsafe fn load ( name : & CStr ) -> Option < Self > {
129
+ // SAFETY: A CStr is always null terminated.
130
+ let module = c:: LoadLibraryA ( name. as_ptr ( ) . cast :: < u8 > ( ) ) ;
131
+ NonNull :: new ( module) . map ( Self )
132
+ }
133
+
121
134
// Try to get the address of a function.
122
135
pub fn proc_address ( self , name : & CStr ) -> Option < NonNull < c_void > > {
123
136
unsafe {
@@ -135,7 +148,8 @@ pub(crate) const UNICOWS_MODULE_NAME: &CStr = ansi_str!("unicows");
135
148
136
149
/// Load a function or use a fallback implementation if that fails.
137
150
macro_rules! compat_fn_with_fallback {
138
- ( pub static $module: ident: & CStr = $name: expr; $(
151
+ ( pub static $module: ident: & CStr = $name: expr;
152
+ const LOAD : bool = $load: expr; $(
139
153
$( #[ $meta: meta] ) *
140
154
$vis: vis fn $symbol: ident( $( $argname: ident: $argtype: ty) ,* ) -> $rettype: ty $fallback_body: block
141
155
) * ) => (
@@ -160,14 +174,22 @@ macro_rules! compat_fn_with_fallback {
160
174
static PTR : AtomicPtr <c_void> = AtomicPtr :: new( load as * mut _) ;
161
175
162
176
unsafe extern "system" fn load( $( $argname: $argtype) ,* ) -> $rettype {
163
- let func = load_from_module( Module :: new ( $module) ) ;
177
+ let func = load_from_module( load_module ( $module) ) ;
164
178
func( $( $argname) ,* )
165
179
}
166
180
181
+ fn load_module( ) -> Option <Module > {
182
+ if $load {
183
+ Module :: load( $module)
184
+ } else {
185
+ Module :: get( $module)
186
+ }
187
+ }
188
+
167
189
fn load_from_module( module: Option <Module >) -> F {
168
190
unsafe {
169
191
static SYMBOL_NAME : & CStr = ansi_str!( sym $symbol) ;
170
- if let Some ( f) = Module :: new ( $crate:: sys:: compat:: UNICOWS_MODULE_NAME )
192
+ if let Some ( f) = Module :: get ( $crate:: sys:: compat:: UNICOWS_MODULE_NAME )
171
193
. and_then( |m| m. proc_address( SYMBOL_NAME ) )
172
194
. or_else( || module. and_then( |m| m. proc_address( SYMBOL_NAME ) ) )
173
195
{
@@ -196,7 +218,7 @@ macro_rules! compat_fn_with_fallback {
196
218
}
197
219
198
220
// Otherwise, the function pointer should be resolved
199
- let _ = load_from_module( Module :: new ( $module) ) ;
221
+ let _ = load_from_module( load_module ( $module) ) ;
200
222
201
223
// After resolution, the function pointer will only point to `fallback` if
202
224
// the target function was not found
@@ -212,6 +234,19 @@ macro_rules! compat_fn_with_fallback {
212
234
$( #[ $meta] ) *
213
235
$vis use $symbol:: call as $symbol;
214
236
) * )
237
+ ( pub static $module: ident: & CStr = $name: expr; $(
238
+ $( #[ $meta: meta] ) *
239
+ $vis: vis fn $symbol: ident( $( $argname: ident: $argtype: ty) ,* ) -> $rettype: ty $fallback_body: block
240
+ ) * ) => (
241
+ compat_fn_with_fallback! {
242
+ pub static $module: & CStr = $name;
243
+ const LOAD : bool = false ;
244
+ $(
245
+ $( #[ $meta] ) *
246
+ $vis fn $symbol( $( $argname: $argtype) ,* ) -> $rettype $fallback_body
247
+ ) *
248
+ }
249
+ )
215
250
}
216
251
217
252
/// Optionally loaded functions.
@@ -257,7 +292,7 @@ pub(super) fn load_synch_functions() {
257
292
258
293
// Try loading the library and all the required functions.
259
294
// If any step fails, then they all fail.
260
- let library = unsafe { Module :: new ( MODULE_NAME ) } ?;
295
+ let library = unsafe { Module :: get ( MODULE_NAME ) } ?;
261
296
let wait_on_address = library. proc_address ( WAIT_ON_ADDRESS ) ?;
262
297
let wake_by_address_single = library. proc_address ( WAKE_BY_ADDRESS_SINGLE ) ?;
263
298
0 commit comments