Skip to content

Commit e96fd37

Browse files
ordinskiyvmoens
andauthored
[Feature] Added implement_for decorator (#618)
* [Feature] Added `implement_for` decorator (#) * `from_version` can be open (`None`). * Changed behaviour in case of missing module/version as was discussed. * Improved tests and exposed `implement_for` in the documentation Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 530dac3 commit e96fd37

File tree

5 files changed

+144
-2
lines changed

5 files changed

+144
-2
lines changed

docs/source/reference/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ API Reference
1010
modules
1111
objectives
1212
trainers
13+
utils

docs/source/reference/utils.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
.. currentmodule:: torchrl._utils
2+
3+
torchrl._utils package
4+
====================
5+
6+
Set of utility methods that are used internally by the library.
7+
8+
9+
.. autosummary::
10+
:toctree: generated/
11+
:template: rl_template.rst
12+
13+
implement_for

test/_utils_internal.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
from torchrl.envs import EnvBase
1616

1717

18+
# Specified for test_utils.py
19+
__version__ = "0.3"
20+
21+
1822
def get_relative_path(curr_file, *path_components):
1923
return os.path.join(os.path.dirname(curr_file), *path_components)
2024

test/test_utils.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
import os
27

38
import pytest
4-
from torchrl._utils import get_binary_env_var
9+
from torchrl._utils import get_binary_env_var, implement_for
510

611

712
@pytest.mark.parametrize("value", ["True", "1", "true"])
@@ -60,3 +65,59 @@ def test_get_binary_env_var_wrong_value():
6065
finally:
6166
if key in os.environ:
6267
del os.environ[key]
68+
69+
70+
class implement_for_test_functions:
71+
"""
72+
Groups functions that are used in tests for `implement_for` decorator.
73+
"""
74+
75+
@staticmethod
76+
@implement_for("_utils_internal", "0.3")
77+
def select_correct_version():
78+
"""To test from+ range and that this function is correctly selected as the implementation."""
79+
return "0.3+"
80+
81+
@staticmethod
82+
@implement_for("_utils_internal", "0.2", "0.3")
83+
def select_correct_version(): # noqa: F811
84+
"""To test that right bound is not included."""
85+
return "0.2-0.3"
86+
87+
@staticmethod
88+
@implement_for("_utils_internal", "0.1", "0.2")
89+
def select_correct_version(): # noqa: F811
90+
"""To test that function with missing from-to range is ignored."""
91+
return "0.1-0.2"
92+
93+
@staticmethod
94+
@implement_for("missing_module")
95+
def missing_module():
96+
"""To test that calling decorated function with missing module raises an exception."""
97+
return "missing"
98+
99+
@staticmethod
100+
@implement_for("_utils_internal", None, "0.3")
101+
def missing_version():
102+
return "0-0.3"
103+
104+
@staticmethod
105+
@implement_for("_utils_internal", "0.4")
106+
def missing_version(): # noqa: F811
107+
return "0.4+"
108+
109+
110+
def test_implement_for():
111+
assert implement_for_test_functions.select_correct_version() == "0.3+"
112+
113+
114+
def test_implement_for_missing_module():
115+
msg = "Supported version of 'missing_module' has not been found."
116+
with pytest.raises(ModuleNotFoundError, match=msg):
117+
implement_for_test_functions.missing_module()
118+
119+
120+
def test_implement_for_missing_version():
121+
msg = "Supported version of '_utils_internal' has not been found."
122+
with pytest.raises(ModuleNotFoundError, match=msg):
123+
implement_for_test_functions.missing_version()

torchrl/_utils.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import math
33
import os
44
import time
5+
from functools import wraps
6+
from importlib import import_module
57

68
import numpy as np
79

@@ -15,6 +17,7 @@ def __init__(self, name):
1517
self.name = name
1618

1719
def __call__(self, fn):
20+
@wraps(fn)
1821
def decorated_fn(*args, **kwargs):
1922
with self:
2023
out = fn(*args, **kwargs)
@@ -122,7 +125,7 @@ def prod(sequence):
122125

123126

124127
def get_binary_env_var(key):
125-
"""Parses and returns the binary enironment variable value.
128+
"""Parses and returns the binary environment variable value.
126129
127130
If not present in environment, it is considered `False`.
128131
@@ -176,3 +179,63 @@ def __repr__(self):
176179

177180

178181
_CKPT_BACKEND = _Dynamic_CKPT_BACKEND()
182+
183+
184+
class implement_for:
185+
"""A version decorator that checks the version in the environment and implements a function with the fitting one.
186+
187+
If specified module is missing or there is no fitting implementation, call of the decorated function
188+
will lead to the explicit error.
189+
In case of intersected ranges, first fitting implementation is used.
190+
191+
Args:
192+
module_name: version is checked for the module with this name (e.g. "gym").
193+
from_version: version from which implementation is compatible. Can be open (None).
194+
to_version: version from which implementation is no longer compatible. Can be open (None).
195+
196+
Examples:
197+
>>> @implement_for(“gym”, “0.13”, “0.14”)
198+
>>> def fun(self, x):
199+
200+
This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+.
201+
"""
202+
203+
# Stores pointers to fitting implementations: dict[func_name] = func_pointer
204+
_implementations = {}
205+
206+
def __init__(
207+
self, module_name: str, from_version: str = None, to_version: str = None
208+
):
209+
self.module_name = module_name
210+
self.from_version = from_version
211+
self.to_version = to_version
212+
213+
def __call__(self, fn):
214+
@wraps(fn)
215+
def unsupported():
216+
raise ModuleNotFoundError(
217+
f"Supported version of '{self.module_name}' has not been found."
218+
)
219+
220+
# If the module is missing replace the function with the mock.
221+
try:
222+
module = import_module(self.module_name)
223+
except ModuleNotFoundError:
224+
return unsupported
225+
226+
func_name = f"{fn.__module__}.{fn.__name__}"
227+
implementations = implement_for._implementations
228+
229+
# Return fitting implementation if it was encountered before.
230+
if func_name in implementations:
231+
return implementations[func_name]
232+
233+
version = module.__version__
234+
235+
if (self.from_version is None or version >= self.from_version) and (
236+
self.to_version is None or version < self.to_version
237+
):
238+
implementations[func_name] = fn
239+
return fn
240+
241+
return unsupported

0 commit comments

Comments
 (0)