@@ -277,7 +277,7 @@ def get_class_that_defined_method(f):
277
277
return out
278
278
279
279
@classmethod
280
- def func_name (cls , fn ):
280
+ def get_func_name (cls , fn ):
281
281
# produces a name like torchrl.module.Class.method or torchrl.module.function
282
282
first = str (fn ).split ("." )[0 ][len ("<function " ) :]
283
283
last = str (fn ).split ("." )[1 :]
@@ -300,10 +300,10 @@ def _get_cls(self, fn):
300
300
301
301
def module_set (self ):
302
302
"""Sets the function in its module, if it exists already."""
303
- prev_setter = type (self )._implementations .get (self .func_name (self .fn ), None )
303
+ prev_setter = type (self )._implementations .get (self .get_func_name (self .fn ), None )
304
304
if prev_setter is not None :
305
305
prev_setter .do_set = False
306
- type(self )._implementations [self .func_name (self .fn )] = self
306
+ type(self )._implementations [self .get_func_name (self .fn )] = self
307
307
cls = self .get_class_that_defined_method (self .fn )
308
308
if cls is not None :
309
309
if cls .__class__ .__name__ == "function" :
@@ -329,11 +329,32 @@ def import_module(cls, module_name: Union[Callable, str]) -> str:
329
329
module = module_name ()
330
330
return module .__version__
331
331
332
+ _lazy_impl = collections .defaultdict (list )
333
+
334
+ def _delazify (self , func_name ):
335
+ for local_call in implement_for ._lazy_impl [func_name ]:
336
+ out = local_call ()
337
+ return out
338
+
332
339
def __call__ (self , fn ):
340
+ # function names are unique
341
+ self .func_name = self .get_func_name (fn )
333
342
self .fn = fn
343
+ implement_for ._lazy_impl [self .func_name ].append (self ._call )
344
+
345
+ @wraps (fn )
346
+ def _lazy_call_fn (* args , ** kwargs ):
347
+ # first time we call the function, we also do the replacement.
348
+ # This will cause the imports to occur only during the first call to fn
349
+ return self ._delazify (self .func_name )(* args , ** kwargs )
350
+
351
+ return _lazy_call_fn
352
+
353
+ def _call (self ):
334
354
335
355
# If the module is missing replace the function with the mock.
336
- func_name = self .func_name (self .fn )
356
+ fn = self .fn
357
+ func_name = self .func_name
337
358
implementations = implement_for ._implementations
338
359
339
360
@wraps (fn )
0 commit comments