Commit 700c2ee
Author: Vincent Moens

Update

[ghstack-poisoned]

2 parents: e7173db + 241d36b
12 files changed: +430 −98 lines
test/test_actors.py

Lines changed: 153 additions & 12 deletions

@@ -3,19 +3,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import importlib.util
 import os

 import pytest
 import torch
-
 from tensordict import NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor

 from torch import distributions as dist, nn
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
+from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
-from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
+from torchrl.modules import (
+    from_hf_transformers,
+    from_vllm,
+    MLP,
+    SafeModule,
+    TanhDelta,
+    TanhNormal,
+)
 from torchrl.modules.tensordict_module.actors import (
     _process_action_space_spec,
     ActorValueOperator,
@@ -37,6 +45,8 @@
 from _utils_internal import get_default_devices
 from mocking_classes import NestedCountingEnv

+_has_vllm = importlib.util.find_spec("vllm") is not None
+

 @pytest.mark.parametrize(
     "log_prob_key",
@@ -908,6 +918,7 @@ def test_lmhead_actorvalueoperator(device):


 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
+@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
 class TestTransformerActor:
     @pytest.mark.parametrize(
         "from_text, generate, tokens, attention_mask",
@@ -924,7 +935,6 @@ class TestTransformerActor:
         ],
     )
     def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
-        from torchrl.data.llm import LLMData
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

         tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -934,26 +944,157 @@ def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask)
         m = from_hf_transformers(
             model, tokenizer=tokenizer, from_text=from_text, generate=generate
         )
+        self._run_check(m, tokens, attention_mask, generate, from_text, has_logits=True)
+
+    @pytest.mark.parametrize(
+        "from_text, generate, tokens, attention_mask",
+        [
+            (True, True, None, None),
+            (True, False, None, None),
+            (
+                False,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm(self, from_text, generate, tokens, attention_mask):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m = from_vllm(model, from_text=from_text, generate=generate)
+        self._run_check(
+            m, tokens, attention_mask, generate, from_text, has_logits=False
+        )
+
+    def _make_data(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        from_text,
+        has_logits,
+        text_response=None,
+        tokens_response=None,
+    ):
+        lp_kwargs = {}
         if from_text:
-            tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
+            if not generate:
+                text_response = (
+                    NonTensorStack(" and another text that follows")
+                    if text_response is None
+                    else text_response
+                )
+                if not isinstance(text_response, NonTensorStack):
+                    if isinstance(text_response, list):
+                        text_response = NonTensorStack(*text_response)
+                    else:
+                        text_response = NonTensorStack(text_response)
+                lp_kwargs.update({"text_response": text_response})
+            tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1)
         else:
-            tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
+            if not generate:
+                if tokens_response is None:
+                    shape_response = tokens.shape
+                    shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
+                    tokens_response = torch.randint(1024, shape_response)
+                lp_kwargs.update({"tokens_response": tokens_response})
+            tdin = LLMData(
+                tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1
+            )
+        return tdin
+
+    def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits):
+        tdin = self._make_data(
+            m, tokens, attention_mask, generate, from_text, has_logits
+        )
+        if from_text and generate:
+            assert tdin.text_response is None
+        elif from_text and not generate:
+            assert tdin.text_response is not None
+
         td = m(tdin)
         assert td is tdin
         assert isinstance(td, LLMData)
         if from_text and generate:
             assert td.text_response is not None
-        else:
-            assert td.text_response is None
-        if attention_mask is not None or from_text:
-            assert td.attention_mask is not None
+        if generate and (attention_mask is not None or from_text):
+            assert td.attention_mask is not None, (generate, attention_mask, from_text)
         else:
             assert td.attention_mask is None
         if not generate:
-            assert td.text_response is None
-            assert td.tokens_response is None
+            # log-probs are computed on text_response or tokens_response
+            assert td.text_response is not None or td.tokens_response is not None
             assert td.log_probs is not None
-            assert td.logits is not None
+            if has_logits:
+                assert td.logits is not None
+
+        # Test the shapes
+        assert td.tokens_response is not None, (generate, has_logits, from_text)
+
+        # If from text and not generating, the tokens are not returned for now
+        if not (from_text and not generate):
+            assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
+            # The convention is that the response only holds the new tokens
+            assert (
+                td.tokens_response[..., : td.tokens.shape[-1]]
+                != td.tokens[..., : td.tokens_response.shape[-1]]
+            ).any()
+
+    @pytest.mark.parametrize(
+        "from_text, tokens, attention_mask",
+        [
+            (True, None, None),
+            (
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m_generate = from_vllm(model, from_text=from_text, generate=True)
+        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
+        self._check_lps(
+            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+        )
+
+    def _check_lps(
+        self,
+        model_generate,
+        model_logprobs,
+        tokens,
+        attention_mask,
+        from_text,
+        has_logits,
+    ):
+        # Check that the log-probs gathered with generate=False match those
+        # obtained with generate=True
+        tdin_generate = self._make_data(
+            model_generate, tokens, attention_mask, True, from_text, has_logits
+        )
+        td_generate = model_generate(tdin_generate)
+        tdin_logprobs = self._make_data(
+            model_logprobs,
+            tokens,
+            attention_mask,
+            False,
+            from_text,
+            has_logits,
+            tokens_response=td_generate.tokens_response,
+            text_response=td_generate.text_response,
+        )
+        td_logprobs = model_logprobs(tdin_logprobs)
+        torch.testing.assert_close(
+            td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+        )


 if __name__ == "__main__":
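
Note: the new tests pin down the from_vllm wrapper's contract — with generate=True it fills text_response/tokens_response on the LLMData carrier; with generate=False it scores a given response and fills log_probs. A minimal sketch of that round trip, assuming only what the tests above call (from_vllm, LLMData, and the facebook/opt-125m checkpoint):

from tensordict import NonTensorStack
from torchrl.data.llm import LLMData
from torchrl.modules import from_vllm
from vllm import LLM

model = LLM(model="facebook/opt-125m")

# Generation: text in, generated text (and tokens) out.
generator = from_vllm(model, from_text=True, generate=True)
td = generator(LLMData(text=NonTensorStack("a text"), batch_size=1))
assert td.text_response is not None

# Scoring: feed the response back with generate=False to get its log-probs.
scorer = from_vllm(model, from_text=True, generate=False)
td_lp = scorer(
    LLMData(
        text=NonTensorStack("a text"),
        text_response=td.text_response,
        batch_size=1,
    )
)
assert td_lp.log_probs is not None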

test/test_env.py

Lines changed: 28 additions & 0 deletions

@@ -1692,6 +1692,34 @@ def test_parallel_env_device(
         env_serial.close(raise_if_closed=False)
         env0.close(raise_if_closed=False)

+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_device", [None, "cpu"])
+    def test_parallel_env_device_vs_no_device(self, maybe_fork_ParallelEnv, env_device):
+        def make_env() -> GymEnv:
+            env = GymEnv(PENDULUM_VERSIONED(), device=env_device)
+            return env.append_transform(DoubleToFloat())
+
+        # Rollout with a ParallelEnv that has no device set
+        parallel_env = maybe_fork_ParallelEnv(
+            num_workers=1, create_env_fn=make_env, device=None
+        )
+        parallel_env.reset()
+        parallel_env.set_seed(0)
+        torch.manual_seed(0)
+
+        parallel_rollout = parallel_env.rollout(max_steps=10)
+
+        # The same rollout with the ParallelEnv pinned to "cpu" should match
+        parallel_env = maybe_fork_ParallelEnv(
+            num_workers=1, create_env_fn=make_env, device="cpu"
+        )
+        parallel_env.reset()
+        parallel_env.set_seed(0)
+        torch.manual_seed(0)
+
+        parallel_rollout_cpu = parallel_env.rollout(max_steps=10)
+        assert_allclose_td(parallel_rollout, parallel_rollout_cpu)
+
     @pytest.mark.skipif(not _has_gym, reason="no gym")
     @pytest.mark.flaky(reruns=3, reruns_delay=1)
     @pytest.mark.parametrize(
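
Note: the assert_allclose_td comparison above only makes sense because both rollouts are seeded identically before stepping. A sketch of that recipe as a helper (seeded_rollout is hypothetical, not part of the test file):

import torch

def seeded_rollout(env, seed: int = 0, max_steps: int = 10):
    # Reset first, then seed the env and torch, mirroring the test's ordering,
    # so two differently configured envs yield comparable trajectories.
    env.reset()
    env.set_seed(seed)
    torch.manual_seed(seed)
    return env.rollout(max_steps=max_steps)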

test/test_storage_map.py

Lines changed: 11 additions & 0 deletions

@@ -350,6 +350,17 @@ def test_edges(self):
         edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
         assert edges == edges_check

+    def test_make_node(self):
+        td = TensorDict({"obs": torch.tensor([0])})
+        tree = Tree(node_data=td)
+        assert tree.node_data is not None
+
+        tree = Tree.make_node(data=td)
+        assert tree.node_data is not None
+
+        tree = Tree.make_node(td)
+        assert tree.node_data is not None
+

 class TestMCTSForest:
     def dummy_rollouts(self) -> Tuple[TensorDict, ...]:
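
Note: the test asserts that three constructions are interchangeable — direct Tree(node_data=...), keyword Tree.make_node(data=...), and positional Tree.make_node(...). A compact sketch (the torchrl.data import path for Tree is assumed):

import torch
from tensordict import TensorDict
from torchrl.data import Tree  # import path assumed

td = TensorDict({"obs": torch.tensor([0])})
# All three forms should yield a tree carrying the data under node_data.
for tree in (Tree(node_data=td), Tree.make_node(data=td), Tree.make_node(td)):
    assert tree.node_data is not None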

torchrl/_utils.py

Lines changed: 15 additions & 1 deletion

@@ -18,7 +18,6 @@
 import warnings
 from contextlib import nullcontext
 from copy import copy
-from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
 from typing import Any, Callable, cast, TypeVar
@@ -35,6 +34,21 @@
 except ImportError:
     from torch._dynamo import is_compiling

+
+def strtobool(val: Any) -> bool:
+    """Convert a string representation of truth to a boolean.
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1';
+    false values are 'n', 'no', 'f', 'false', 'off', and '0'.
+    Raises ValueError if 'val' is anything else.
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    if val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    raise ValueError(f"Invalid truth value {val!r}")
+
+
 LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
 logger = logging.getLogger("torchrl")
 logger.setLevel(getattr(logging, LOGGING_LEVEL))
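
Note: distutils (and with it distutils.util.strtobool) was removed from the standard library in Python 3.12 (PEP 632), hence the vendored replacement above. Unlike the distutils original, which returned the ints 1/0, this version returns a real bool. Usage sketch (the RL_VERBOSE variable is made up for illustration):

import os
from torchrl._utils import strtobool

# Typical use: parse boolean-ish environment flags.
verbose = strtobool(os.environ.get("RL_VERBOSE", "0"))
assert strtobool("Yes") is True and strtobool("off") is False

try:
    strtobool("maybe")
except ValueError as err:
    print(err)  # Invalid truth value 'maybe'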

torchrl/data/map/tree.py

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ def make_node(
         return cls(
             count=torch.zeros(()),
             wins=torch.zeros(()),
-            node=data.exclude("action", "next"),
+            node_data=data.exclude("action", "next"),
             rollout=rollout,
             subtree=subtree,
             device=device,
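
Note: this one-line rename is what the new test_make_node in test_storage_map.py guards. make_node forwarded the filtered data as node=, while the Tree class stores it under node_data, so a call like the sketch below would previously have failed inside make_node (presumably with an unexpected-keyword TypeError):

import torch
from tensordict import TensorDict
from torchrl.data import Tree  # import path assumed

# Fails before the fix, succeeds after: the data lands in tree.node_data.
tree = Tree.make_node(TensorDict({"obs": torch.tensor([0])}))
assert tree.node_data is not None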
