Commit 98db6d0

Author: Vincent Moens

[Feature] vllm wrapper

ghstack-source-id: 88fe78a
Pull Request resolved: #2830

Parent: 2ce6b7c
6 files changed: +520 -16 lines
test/test_actors.py (153 additions, 12 deletions)
@@ -3,19 +3,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import importlib.util
 import os
 
 import pytest
 import torch
-
 from tensordict import NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
+from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
-from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
+from torchrl.modules import (
+    from_hf_transformers,
+    from_vllm,
+    MLP,
+    SafeModule,
+    TanhDelta,
+    TanhNormal,
+)
 from torchrl.modules.tensordict_module.actors import (
     _process_action_space_spec,
     ActorValueOperator,
@@ -37,6 +45,8 @@
 from _utils_internal import get_default_devices
 from mocking_classes import NestedCountingEnv
 
+_has_vllm = importlib.util.find_spec("vllm") is not None
+
 
 @pytest.mark.parametrize(
     "log_prob_key",
@@ -908,6 +918,7 @@ def test_lmhead_actorvalueoperator(device):
 
 
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
+@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
 class TestTransformerActor:
     @pytest.mark.parametrize(
         "from_text, generate, tokens, attention_mask",
@@ -924,7 +935,6 @@ class TestTransformerActor:
         ],
     )
     def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
-        from torchrl.data.llm import LLMData
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
 
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -934,26 +944,157 @@ def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask)
         m = from_hf_transformers(
             model, tokenizer=tokenizer, from_text=from_text, generate=generate
         )
+        self._run_check(m, tokens, attention_mask, generate, from_text, has_logits=True)
+
+    @pytest.mark.parametrize(
+        "from_text, generate, tokens, attention_mask",
+        [
+            (True, True, None, None),
+            (True, False, None, None),
+            (
+                False,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm(self, from_text, generate, tokens, attention_mask):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m = from_vllm(model, from_text=from_text, generate=generate)
+        self._run_check(
+            m, tokens, attention_mask, generate, from_text, has_logits=False
+        )
+
+    def _make_data(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        from_text,
+        has_logits,
+        text_response=None,
+        tokens_response=None,
+    ):
+        lp_kwargs = {}
         if from_text:
-            tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
+            if not generate:
+                text_response = (
+                    NonTensorStack(" and another text that follows")
+                    if text_response is None
+                    else text_response
+                )
+                if not isinstance(text_response, NonTensorStack):
+                    if isinstance(text_response, list):
+                        text_response = NonTensorStack(*text_response)
+                    else:
+                        text_response = NonTensorStack(text_response)
+                lp_kwargs.update({"text_response": text_response})
+            tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1)
         else:
-            tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
+            if not generate:
+                if tokens_response is None:
+                    shape_response = tokens.shape
+                    shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
+                    tokens_response = torch.randint(1024, shape_response)
+                lp_kwargs.update({"tokens_response": tokens_response})
+            tdin = LLMData(
+                tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1
+            )
+        return tdin
+
+    def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits):
+        tdin = self._make_data(
+            m, tokens, attention_mask, generate, from_text, has_logits
+        )
+        if from_text and generate:
+            assert tdin.text_response is None
+        elif from_text and not generate:
+            assert tdin.text_response is not None
+
         td = m(tdin)
         assert td is tdin
         assert isinstance(td, LLMData)
         if from_text and generate:
             assert td.text_response is not None
-        else:
-            assert td.text_response is None
-        if attention_mask is not None or from_text:
-            assert td.attention_mask is not None
+        if generate and (attention_mask is not None or from_text):
+            assert td.attention_mask is not None, (generate, attention_mask, from_text)
         else:
             assert td.attention_mask is None
         if not generate:
-            assert td.text_response is None
-            assert td.tokens_response is None
+            # log-probs are computed on text_response or tokens_response
+            assert td.text_response is not None or td.tokens_response is not None
             assert td.log_probs is not None
-            assert td.logits is not None
+            if has_logits:
+                assert td.logits is not None
+
+        # Test the shapes
+        assert td.tokens_response is not None, (generate, has_logits, from_text)
+
+        # If from text and not generating, the tokens are not returned for now
+        if not (from_text and not generate):
+            assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
+            # The convention is that the response only has new tokens
+            assert (
+                td.tokens_response[..., : td.tokens.shape[-1]]
+                != td.tokens[..., : td.tokens_response.shape[-1]]
+            ).any()
+
+    @pytest.mark.parametrize(
+        "from_text, tokens, attention_mask",
+        [
+            (True, None, None),
+            (
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+        from vllm import LLM
+
+        model = LLM(model="facebook/opt-125m")
+        m_generate = from_vllm(model, from_text=from_text, generate=True)
+        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
+        self._check_lps(
+            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+        )
+
+    def _check_lps(
+        self,
+        model_generate,
+        model_logprobs,
+        tokens,
+        attention_mask,
+        from_text,
+        has_logits,
+    ):
+        # Checks that the log-probs gathered with generate=False match those with generate=True
+        tdin_generate = self._make_data(
+            model_generate, tokens, attention_mask, True, from_text, has_logits
+        )
+        td_generate = model_generate(tdin_generate)
+        tdin_logprobs = self._make_data(
+            model_logprobs,
+            tokens,
+            attention_mask,
+            False,
+            from_text,
+            has_logits,
+            tokens_response=td_generate.tokens_response,
+            text_response=td_generate.text_response,
+        )
+        td_logprobs = model_logprobs(tdin_logprobs)
+        print(td_generate.log_probs / td_logprobs.log_probs)
+        torch.testing.assert_close(
+            td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+        )
 
 
 if __name__ == "__main__":
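
Distilled from the tests above, here is a minimal usage sketch of the new wrapper. This is a hedged illustration, not part of the commit: it assumes vllm is installed and the facebook/opt-125m checkpoint can be downloaded, and it uses only the LLMData fields the tests exercise.

# Minimal usage sketch of from_vllm, distilled from the tests above.
from tensordict import NonTensorStack
from vllm import LLM

from torchrl.data.llm import LLMData
from torchrl.modules import from_vllm

model = LLM(model="facebook/opt-125m")

# Generation: text in, sampled text (plus tokens and log-probs) out.
policy = from_vllm(model, from_text=True, generate=True)
out = policy(LLMData(text=NonTensorStack("a text"), batch_size=1))
print(out.text_response)

# Scoring: feed the response back with generate=False to recover the
# log-probabilities of those same tokens, as test_from_vllm_logprobs does.
scorer = from_vllm(model, from_text=True, generate=False)
td = scorer(
    LLMData(
        text=NonTensorStack("a text"),
        text_response=out.text_response,
        batch_size=1,
    )
)
print(td.log_probs)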

torchrl/envs/utils.py (3 additions, 1 deletion)
@@ -942,7 +942,9 @@ def make_shape(shape):
         )
         if is_tensor_collection(tensor) and not is_non_tensor(tensor)
         else NonTensor(
-            shape=tensor.shape, example_data=tensor.data, device=tensor.device
+            shape=tensor.shape,
+            example_data=tensor.data,
+            device=tensor.device,
         )
         if is_non_tensor(tensor)
         else Unbounded(
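
The hunk above is purely cosmetic: the NonTensor spec is still built from the same three attributes of the non-tensor leaf. A small sketch of the equivalent construction, where the NonTensorData input is a hypothetical stand-in for whatever leaf the env produces:

# Sketch of the spec construction the reflowed branch performs
# (the NonTensorData leaf here is a hypothetical example).
from tensordict import NonTensorData
from torchrl.data import NonTensor

tensor = NonTensorData("some string", batch_size=[])
spec = NonTensor(
    shape=tensor.shape,
    example_data=tensor.data,
    device=tensor.device,
)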

torchrl/modules/__init__.py (2 additions, 1 deletion)
@@ -93,7 +93,7 @@
 )
 from .utils import get_primers_from_module
 from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner  # usort:skip
-from .llm import from_hf_transformers
+from .llm import from_hf_transformers, from_vllm
 
 __all__ = [
     "Actor",
@@ -178,6 +178,7 @@
     "WorldModelWrapper",
     "distributions_maps",
     "from_hf_transformers",
+    "from_vllm",
     "get_primers_from_module",
     "recurrent_mode",
     "reset_noise",

torchrl/modules/llm/__init__.py (2 additions, 1 deletion)
@@ -4,5 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from .transformers_policy import from_hf_transformers
+from .vllm_policy import from_vllm
 
-__all__ = ["from_hf_transformers"]
+__all__ = ["from_hf_transformers", "from_vllm"]

torchrl/modules/llm/transformers_policy.py (0 additions, 1 deletion)
@@ -121,7 +121,6 @@ def from_hf_transformers(
 ) -> TensorDictModuleBase:
     # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
 
-
     module_dict = {}
     if device:
         module_dict["clear_device"] = _maybe_clear_device
