Skip to content

Commit 9e3c4df

Browse files
author
Vincent Moens
committed
[Feature] Allow null-dim inputs in vLLMWrapper and TransformersWrapper
ghstack-source-id: f9985f8 Pull Request resolved: #2899
1 parent daa67cb commit 9e3c4df

File tree

16 files changed

+183
-106
lines changed

16 files changed

+183
-106
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ echo "installing gymnasium"
9797
if [[ "$PYTHON_VERSION" == "3.12" ]]; then
9898
pip3 install ale-py
9999
pip3 install sympy
100-
pip3 install "gymnasium[accept-rom-license,mujoco]>=1.1" mo-gymnasium[mujoco]
100+
pip3 install "gymnasium[mujoco]>=1.1" mo-gymnasium[mujoco]
101101
else
102-
pip3 install "gymnasium[atari,accept-rom-license,mujoco]>=1.1" mo-gymnasium[mujoco]
102+
pip3 install "gymnasium[atari,mujoco]>=1.1" mo-gymnasium[mujoco]
103103
fi
104104
pip3 install "mujoco" -U
105105

.github/unittest/linux_distributed/scripts/setup_env.sh

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -96,32 +96,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
9696
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
9797
# rename them
9898
PY_VERSION=$(python --version)
99-
echo "installing ale-py for ${PY_PY_VERSION}"
100-
if [[ $PY_VERSION == *"3.7"* ]]; then
101-
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
102-
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
103-
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
104-
elif [[ $PY_VERSION == *"3.8"* ]]; then
105-
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
106-
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
107-
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
108-
elif [[ $PY_VERSION == *"3.9"* ]]; then
109-
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
110-
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
111-
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
112-
elif [[ $PY_VERSION == *"3.10"* ]]; then
113-
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
114-
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
115-
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
116-
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
117-
elif [[ $PY_VERSION == *"3.11"* ]]; then
118-
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
119-
mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
120-
pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
121-
rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
122-
fi
12399
echo "installing gymnasium"
124-
pip install "gymnasium[atari,accept-rom-license]>=1.1"
100+
pip install "gymnasium[atari]>=1.1"
125101
else
126-
pip install "gymnasium[atari,accept-rom-license]>=1.1"
102+
pip install "gymnasium[atari]>=1.1"
127103
fi

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ do
126126
conda activate ./cloned_env
127127

128128
echo "Testing gym version: ${GYM_VERSION}"
129-
pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION
129+
pip3 install 'gymnasium[atari,ale-py]'==$GYM_VERSION
130130

131131
$DIR/run_test.sh
132132

@@ -140,7 +140,7 @@ conda deactivate
140140
conda create --prefix ./cloned_env --clone ./env -y
141141
conda activate ./cloned_env
142142

143-
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]>=1.1.0' mo-gymnasium gymnasium-robotics -U
143+
pip3 install 'gymnasium[ale-py,atari]>=1.1.0' mo-gymnasium gymnasium-robotics -U
144144

145145
$DIR/run_test.sh
146146

@@ -155,7 +155,7 @@ conda deactivate
155155
conda create --prefix ./cloned_env --clone ./env -y
156156
conda activate ./cloned_env
157157

158-
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]>=1.1.0' mo-gymnasium gymnasium-robotics -U
158+
pip3 install 'gymnasium[ale-py,atari]>=1.1.0' mo-gymnasium gymnasium-robotics -U
159159

160160
$DIR/run_test.sh
161161

.github/unittest/linux_libs/scripts_llm/run_test.sh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ ln -s /usr/bin/swig3.0 /usr/bin/swig
1010

1111
export PYTORCH_TEST_WITH_SLOW='1'
1212
export LAZY_LEGACY_OP=False
13+
14+
# to solve RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
15+
export VLLM_WORKER_MULTIPROC_METHOD=spawn
1316
python -m torch.utils.collect_env
1417
# Avoid error: "fatal: unsafe repository"
1518
git config --global --add safe.directory '*'
@@ -22,11 +25,11 @@ conda deactivate && conda activate ./env
2225

2326
python -c "import transformers, datasets"
2427

25-
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
28+
pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
2629

27-
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_actors.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow
30+
pytest test/test_actors.py test/test_collector.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow
2831

29-
python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
32+
pytest examples/rlhf/train_rlhf.py \
3033
sys.device=cuda:0 sys.ref_device=cuda:0 \
3134
model.name_or_path=gpt2 train.max_epochs=2 \
3235
data.batch_size=2 train.ppo.ppo_batch_size=2 \

.github/unittest/linux_sota/scripts/run_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ python -c """import gym;import d4rl"""
112112

113113
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
114114
# rename them
115-
pip install "gymnasium[atari,accept-rom-license]>=1.1.0"
115+
pip install "gymnasium[atari]>=1.1.0"
116116

117117
# ============================================================================================ #
118118
# ================================ PyTorch & TorchRL ========================================= #

sota-check/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export MUJOCO_GL=egl
2626
conda create -n rl-sota-bench python=3.10 -y
2727
conda install anaconda::libglu -y
2828
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
29-
pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame "moviepy<2.0.0" imageio submitit hydra-core transformers
29+
pip3 install "gymnasium[atari,mujoco]" vmas tqdm wandb pygame "moviepy<2.0.0" imageio submitit hydra-core transformers
3030

3131
cd /path/to/tensordict
3232
python setup.py develop

test/_utils_internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _set_gym_environments(): # noqa: F811
156156
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED
157157

158158
_CARTPOLE_VERSIONED = "CartPole-v1"
159-
_HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
159+
_HALFCHEETAH_VERSIONED = "HalfCheetah-v5"
160160
_PENDULUM_VERSIONED = "Pendulum-v1"
161161
_PONG_VERSIONED = "ALE/Pong-v5"
162162
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"

test/test_env.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
ParallelEnv,
5757
PendulumEnv,
5858
SerialEnv,
59+
set_gym_backend,
5960
TicTacToeEnv,
6061
)
6162
from torchrl.envs.batched_envs import _stackable
@@ -2511,6 +2512,7 @@ def test_info_dict_reader(self, device, seed=0):
25112512
import gymnasium as gym
25122513
except ModuleNotFoundError:
25132514
import gym
2515+
set_gym_backend(gym).set()
25142516

25152517
env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED()), device=device)
25162518
env.set_info_dict_reader(
@@ -2542,7 +2544,7 @@ def test_info_dict_reader(self, device, seed=0):
25422544
),
25432545
[Unbounded((), dtype=torch.float64)],
25442546
):
2545-
env2 = GymWrapper(gym.make("HalfCheetah-v4"))
2547+
env2 = GymWrapper(gym.make("HalfCheetah-v5"))
25462548
env2.set_info_dict_reader(
25472549
default_info_dict_reader(["x_position"], spec=spec)
25482550
)

test/test_libs.py

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def maybe_init_minigrid():
190190
minigrid.register_minigrid_envs()
191191

192192

193+
@implement_for("gym")
193194
def get_gym_pixel_wrapper():
194195
try:
195196
# works whenever gym_version > version.parse("0.19")
@@ -203,6 +204,29 @@ def get_gym_pixel_wrapper():
203204
return PixelObservationWrapper
204205

205206

207+
@implement_for("gymnasium", None, "1.1.0")
208+
def get_gym_pixel_wrapper(): # noqa: F811
209+
try:
210+
# works whenever gym_version > version.parse("0.19")
211+
PixelObservationWrapper = gym_backend(
212+
"wrappers.pixel_observation"
213+
).PixelObservationWrapper
214+
except Exception:
215+
from torchrl.envs.libs.utils import (
216+
GymPixelObservationWrapper as PixelObservationWrapper,
217+
)
218+
return PixelObservationWrapper
219+
220+
221+
@implement_for("gymnasium", "1.1.0")
222+
def get_gym_pixel_wrapper(): # noqa: F811
223+
# works whenever gym_version > version.parse("0.19")
224+
PixelObservationWrapper = lambda *args, pixels_only=False, **kwargs: gym_backend(
225+
"wrappers"
226+
).AddRenderObservation(*args, render_only=pixels_only, **kwargs)
227+
return PixelObservationWrapper
228+
229+
206230
if _has_gym:
207231
try:
208232
from gymnasium import __version__ as gym_version
@@ -1030,7 +1054,12 @@ def test_one_hot_and_categorical(self): # noqa: F811
10301054
)
10311055
@pytest.mark.flaky(reruns=5, reruns_delay=1)
10321056
def test_vecenvs_wrapper(self, envname):
1033-
self._test_vecenvs_wrapper(envname)
1057+
import gymnasium
1058+
1059+
with set_gym_backend("gymnasium"):
1060+
self._test_vecenvs_wrapper(
1061+
envname, kwargs={"reset_mode": gymnasium.vector.AutoresetMode.SAME_STEP}
1062+
)
10341063

10351064
@implement_for("gymnasium", None, "1.0.0")
10361065
@pytest.mark.parametrize(
@@ -1040,22 +1069,25 @@ def test_vecenvs_wrapper(self, envname):
10401069
)
10411070
@pytest.mark.flaky(reruns=5, reruns_delay=1)
10421071
def test_vecenvs_wrapper(self, envname): # noqa
1043-
self._test_vecenvs_wrapper(envname)
1072+
with set_gym_backend("gymnasium"):
1073+
self._test_vecenvs_wrapper(envname)
10441074

1045-
def _test_vecenvs_wrapper(self, envname):
1075+
def _test_vecenvs_wrapper(self, envname, kwargs=None):
10461076
import gymnasium
10471077

1078+
if kwargs is None:
1079+
kwargs = {}
10481080
# we can't use parametrize with implement_for
10491081
env = GymWrapper(
10501082
gymnasium.vector.SyncVectorEnv(
1051-
2 * [lambda envname=envname: gymnasium.make(envname)]
1083+
2 * [lambda envname=envname: gymnasium.make(envname)], **kwargs
10521084
)
10531085
)
10541086
assert env.batch_size == torch.Size([2])
10551087
check_env_specs(env)
10561088
env = GymWrapper(
10571089
gymnasium.vector.AsyncVectorEnv(
1058-
2 * [lambda envname=envname: gymnasium.make(envname)]
1090+
2 * [lambda envname=envname: gymnasium.make(envname)], **kwargs
10591091
)
10601092
)
10611093
assert env.batch_size == torch.Size([2])
@@ -1113,25 +1145,26 @@ def _test_vecenvs_env(self, envname):
11131145
)
11141146
@pytest.mark.flaky(reruns=5, reruns_delay=1)
11151147
def test_vecenvs_wrapper(self, envname): # noqa: F811
1116-
gym = gym_backend()
1117-
# we can't use parametrize with implement_for
1118-
for envname in ["CartPole-v1", "HalfCheetah-v4"]:
1119-
env = GymWrapper(
1120-
gym.vector.SyncVectorEnv(
1121-
2 * [lambda envname=envname: gym.make(envname)]
1148+
with set_gym_backend("gym"):
1149+
gym = gym_backend()
1150+
# we can't use parametrize with implement_for
1151+
for envname in ["CartPole-v1", "HalfCheetah-v4"]:
1152+
env = GymWrapper(
1153+
gym.vector.SyncVectorEnv(
1154+
2 * [lambda envname=envname: gym.make(envname)]
1155+
)
11221156
)
1123-
)
1124-
assert env.batch_size == torch.Size([2])
1125-
check_env_specs(env)
1126-
env = GymWrapper(
1127-
gym.vector.AsyncVectorEnv(
1128-
2 * [lambda envname=envname: gym.make(envname)]
1157+
assert env.batch_size == torch.Size([2])
1158+
check_env_specs(env)
1159+
env = GymWrapper(
1160+
gym.vector.AsyncVectorEnv(
1161+
2 * [lambda envname=envname: gym.make(envname)]
1162+
)
11291163
)
1130-
)
1131-
assert env.batch_size == torch.Size([2])
1132-
check_env_specs(env)
1133-
env.close()
1134-
del env
1164+
assert env.batch_size == torch.Size([2])
1165+
check_env_specs(env)
1166+
env.close()
1167+
del env
11351168

11361169
@implement_for("gym", "0.18")
11371170
@pytest.mark.parametrize(
@@ -1150,17 +1183,17 @@ def test_vecenvs_env(self, envname): # noqa: F811
11501183
env = GymEnv(envname, num_envs=2, from_pixels=False)
11511184
env.set_seed(0)
11521185
assert env.get_library_name(env._env) == "gym"
1153-
# rollouts can be executed without decorator
1154-
check_env_specs(env)
1155-
rollout = env.rollout(100, break_when_any_done=False)
1156-
for obs_key in env.observation_spec.keys(True, True):
1157-
rollout_consistency_assertion(
1158-
rollout,
1159-
done_key="done",
1160-
observation_key=obs_key,
1161-
done_strict="CartPole" in envname,
1162-
)
1163-
env.close()
1186+
# rollouts can be executed without decorator
1187+
check_env_specs(env)
1188+
rollout = env.rollout(100, break_when_any_done=False)
1189+
for obs_key in env.observation_spec.keys(True, True):
1190+
rollout_consistency_assertion(
1191+
rollout,
1192+
done_key="done",
1193+
observation_key=obs_key,
1194+
done_strict="CartPole" in envname,
1195+
)
1196+
env.close()
11641197
del env
11651198
if envname != "CartPole-v1":
11661199
with set_gym_backend("gym"):
@@ -1469,7 +1502,7 @@ def reset(
14691502
{},
14701503
)
14711504

1472-
yield CountingEnvRandomReset
1505+
return CountingEnvRandomReset
14731506

14741507
@implement_for("gym")
14751508
def test_gymnasium_autoreset(self, venv):
@@ -1484,6 +1517,8 @@ def test_gymnasium_autoreset(self, venv): # noqa
14841517
def test_gymnasium_autoreset(self, venv): # noqa
14851518
import gymnasium as gym
14861519

1520+
set_gym_backend("gymnasium").set()
1521+
14871522
counting_env = self.counting_env()
14881523
if venv == "sync":
14891524
venv = gym.vector.SyncVectorEnv

torchrl/_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,17 @@ def module_set(self):
423423
else:
424424
# class not yet defined
425425
return
426+
try:
427+
delattr(cls, self.fn.__name__)
428+
except AttributeError:
429+
pass
430+
431+
name = self.fn.__name__
426432
if self.class_method:
427-
setattr(cls, self.fn.__name__, classmethod(self.fn))
433+
fn = classmethod(self.fn)
428434
else:
429-
setattr(cls, self.fn.__name__, self.fn)
435+
fn = self.fn
436+
setattr(cls, name, fn)
430437

431438
@classmethod
432439
def import_module(cls, module_name: Callable | str) -> str:
@@ -543,7 +550,7 @@ def __repr__(self):
543550
return (
544551
f"{self.__class__.__name__}("
545552
f"module_name={self.module_name}({self.from_version, self.to_version}), "
546-
f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})"
553+
f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)})"
547554
)
548555

549556

0 commit comments

Comments
 (0)