Skip to content

Commit a67b9fb

Browse files
author
Vincent Moens
authored
[Feature] Lazy imports for implement_for during torchrl import (#1646)
1 parent f8788b1 commit a67b9fb

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

test/test_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,11 @@ def test_implement_for_missing_version():
151151
def test_implement_for_reset():
152152
assert implement_for_test_functions.select_correct_version() == "0.3+"
153153
_impl = copy(implement_for._implementations)
154-
name = implement_for.func_name(implement_for_test_functions.select_correct_version)
154+
name = implement_for.get_func_name(
155+
implement_for_test_functions.select_correct_version
156+
)
155157
for setter in implement_for._setters:
156-
if implement_for.func_name(setter.fn) == name and setter.fn() != "0.3+":
158+
if implement_for.get_func_name(setter.fn) == name and setter.fn() != "0.3+":
157159
setter.module_set()
158160
assert implement_for_test_functions.select_correct_version() != "0.3+"
159161
implement_for.reset(_impl)

torchrl/_utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def get_class_that_defined_method(f):
277277
return out
278278

279279
@classmethod
280-
def func_name(cls, fn):
280+
def get_func_name(cls, fn):
281281
# produces a name like torchrl.module.Class.method or torchrl.module.function
282282
first = str(fn).split(".")[0][len("<function ") :]
283283
last = str(fn).split(".")[1:]
@@ -300,10 +300,10 @@ def _get_cls(self, fn):
300300

301301
def module_set(self):
302302
"""Sets the function in its module, if it exists already."""
303-
prev_setter = type(self)._implementations.get(self.func_name(self.fn), None)
303+
prev_setter = type(self)._implementations.get(self.get_func_name(self.fn), None)
304304
if prev_setter is not None:
305305
prev_setter.do_set = False
306-
type(self)._implementations[self.func_name(self.fn)] = self
306+
type(self)._implementations[self.get_func_name(self.fn)] = self
307307
cls = self.get_class_that_defined_method(self.fn)
308308
if cls is not None:
309309
if cls.__class__.__name__ == "function":
@@ -329,11 +329,32 @@ def import_module(cls, module_name: Union[Callable, str]) -> str:
329329
module = module_name()
330330
return module.__version__
331331

332+
_lazy_impl = collections.defaultdict(list)
333+
334+
def _delazify(self, func_name):
335+
for local_call in implement_for._lazy_impl[func_name]:
336+
out = local_call()
337+
return out
338+
332339
def __call__(self, fn):
340+
# function names are unique
341+
self.func_name = self.get_func_name(fn)
333342
self.fn = fn
343+
implement_for._lazy_impl[self.func_name].append(self._call)
344+
345+
@wraps(fn)
346+
def _lazy_call_fn(*args, **kwargs):
347+
# first time we call the function, we also do the replacement.
348+
# This will cause the imports to occur only during the first call to fn
349+
return self._delazify(self.func_name)(*args, **kwargs)
350+
351+
return _lazy_call_fn
352+
353+
def _call(self):
334354

335355
# If the module is missing replace the function with the mock.
336-
func_name = self.func_name(self.fn)
356+
fn = self.fn
357+
func_name = self.func_name
337358
implementations = implement_for._implementations
338359

339360
@wraps(fn)

0 commit comments

Comments
 (0)