
[Feature] vllm wrapper #2830


Merged · 10 commits · Mar 11, 2025
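Before the diff, a minimal usage sketch of the vllm wrapper this PR adds, mirroring the calls made in test_from_vllm further down. The names from_vllm, LLMData, NonTensorStack, and LLM(model="facebook/opt-125m") are taken from the diff itself; everything else is illustrative only, not a definitive API description.

# Sketch assuming only the call signatures visible in test/test_actors.py below.
import torch
from tensordict import NonTensorStack
from torchrl.data.llm import LLMData
from torchrl.modules import from_vllm
from vllm import LLM

model = LLM(model="facebook/opt-125m")  # same model the tests load

# Text mode: feed raw text, get a generated response plus log-probs back.
gen = from_vllm(model, from_text=True, generate=True, return_log_probs=True)
td = gen(LLMData(text=NonTensorStack("a text"), batch_size=1))
assert td.text_response is not None and td.log_probs is not None

# Token mode: feed tokens (and optionally an attention mask) instead of text.
tok = from_vllm(model, from_text=False, generate=True, return_log_probs=True)
tokens = torch.randint(1024, (1, 10))
mask = torch.ones(1, 10, dtype=torch.int64)
td = tok(LLMData(tokens=tokens, attention_mask=mask, batch_size=1))
assert td.tokens_response is not None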
@@ -17,5 +17,6 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- transformers<4.42.0
- transformers
- datasets
- vllm
@@ -26,23 +26,24 @@ fi
# submodules
git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with cu128"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
else
pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
fi
else
printf "Failed to install pytorch"
exit 1
fi
# We skip the PyTorch install due to vllm requirements
#printf "Installing PyTorch with cu128"
#if [[ "$TORCH_VERSION" == "nightly" ]]; then
# if [ "${CU_VERSION:-}" == cpu ] ; then
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
# else
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
# fi
#elif [[ "$TORCH_VERSION" == "stable" ]]; then
# if [ "${CU_VERSION:-}" == cpu ] ; then
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
# else
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
# fi
#else
# printf "Failed to install pytorch"
# exit 1
#fi

# install tensordict
if [[ "$RELEASE" == 0 ]]; then
@@ -24,6 +24,8 @@ python -c "import transformers, datasets"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_actors.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow

python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
sys.device=cuda:0 sys.ref_device=cuda:0 \
model.name_or_path=gpt2 train.max_epochs=2 \
@@ -1,4 +1,4 @@
name: RLHF Tests on Linux
name: LLM Tests on Linux

on:
pull_request:
@@ -50,7 +50,7 @@ jobs:
export TF_CPP_MIN_LOG_LEVEL=0
export TD_GET_DEFAULTS_TO_NONE=1

bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh
bash .github/unittest/linux_libs/scripts_rlhf/install.sh
bash .github/unittest/linux_libs/scripts_rlhf/run_test.sh
bash .github/unittest/linux_libs/scripts_rlhf/post_process.sh
bash .github/unittest/linux_libs/scripts_llm/setup_env.sh
bash .github/unittest/linux_libs/scripts_llm/install.sh
bash .github/unittest/linux_libs/scripts_llm/run_test.sh
bash .github/unittest/linux_libs/scripts_llm/post_process.sh
examples/rlhf/models/actor_critic.py (2 changes: 1 addition & 1 deletion)
@@ -34,4 +34,4 @@ def init_actor_critic(model_cfg, sys_cfg):
critic = model.get_value_operator()
critic_head = model.get_value_head()

return actor, VmapModule(critic), critic_head, base_model
return actor, VmapModule(critic, mock=True), critic_head, base_model
test/test_actors.py (253 changes: 232 additions & 21 deletions)
@@ -3,19 +3,27 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import importlib.util
import os

import pytest
import torch

from tensordict import NonTensorStack, TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torch import distributions as dist, nn
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
from torchrl.data.llm import LLMData
from torchrl.data.llm.dataset import _has_transformers
from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
from torchrl.modules import (
from_hf_transformers,
from_vllm,
MLP,
SafeModule,
TanhDelta,
TanhNormal,
)
from torchrl.modules.tensordict_module.actors import (
_process_action_space_spec,
ActorValueOperator,
@@ -37,6 +45,8 @@
from _utils_internal import get_default_devices
from mocking_classes import NestedCountingEnv

_has_vllm = importlib.util.find_spec("vllm") is not None


@pytest.mark.parametrize(
"log_prob_key",
@@ -908,52 +918,253 @@ def test_lmhead_actorvalueoperator(device):


@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
class TestTransformerActor:
@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
class TestLLMActor:
@pytest.mark.parametrize(
"from_text, generate, tokens, attention_mask",
"from_text, generate, return_log_probs, tokens, attention_mask",
[
(True, True, None, None),
(True, False, None, None),
(True, True, True, None, None),
(True, True, False, None, None),
(True, False, None, None, None),
(
False,
True,
True,
torch.randint(1024, (1, 10)),
torch.ones(1, 10, dtype=torch.int64),
),
(False, True, True, torch.randint(1024, (1, 10)), None),
(
False,
True,
False,
torch.randint(1024, (1, 10)),
torch.ones(1, 10, dtype=torch.int64),
),
(False, True, torch.randint(1024, (1, 10)), None),
(False, True, False, torch.randint(1024, (1, 10)), None),
],
)
def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
from torchrl.data.llm import LLMData
def test_from_hf_transformers(
self, from_text, generate, return_log_probs, tokens, attention_mask
):
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny"
# Load the model and tokenizer
# model = AutoModel.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel(GPT2Config())

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

m = from_hf_transformers(
model, tokenizer=tokenizer, from_text=from_text, generate=generate
model,
tokenizer=tokenizer,
from_text=from_text,
generate=generate,
return_log_probs=return_log_probs,
)
self._run_check(
m,
tokens,
attention_mask,
generate,
return_log_probs,
from_text,
has_logits=True,
)

@pytest.mark.parametrize(
"from_text, generate, return_log_probs, tokens, attention_mask",
[
(True, True, True, None, None),
(True, True, False, None, None),
(True, False, None, None, None),
(
False,
True,
True,
torch.randint(1024, (1, 10)),
torch.ones(1, 10, dtype=torch.int64),
),
(False, True, True, torch.randint(1024, (1, 10)), None),
(
False,
True,
False,
torch.randint(1024, (1, 10)),
torch.ones(1, 10, dtype=torch.int64),
),
(False, True, False, torch.randint(1024, (1, 10)), None),
],
)
def test_from_vllm(
self, from_text, generate, return_log_probs, tokens, attention_mask
):
from vllm import LLM

model = LLM(model="facebook/opt-125m")
m = from_vllm(
model,
from_text=from_text,
generate=generate,
return_log_probs=return_log_probs,
)
self._run_check(
m,
tokens,
attention_mask,
generate,
return_log_probs,
from_text,
has_logits=False,
)

def _make_data(
self,
m,
tokens,
attention_mask,
generate,
from_text,
has_logits,
text_response=None,
tokens_response=None,
):
lp_kwargs = {}
if from_text:
tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
if not generate:
text_response = (
NonTensorStack(" and another text that follows")
if text_response is None
else text_response
)
if not isinstance(text_response, NonTensorStack):
if isinstance(text_response, list):
text_response = NonTensorStack(*text_response)
else:
text_response = NonTensorStack(text_response)
lp_kwargs.update({"text_response": text_response})
tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1)
else:
tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
if not generate:
if tokens_response is None:
shape_response = tokens.shape
shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
tokens_response = torch.randint(1024, shape_response)
lp_kwargs.update({"tokens_response": tokens_response})
tdin = LLMData(
tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1
)
return tdin

def _run_check(
self,
m,
tokens,
attention_mask,
generate,
return_log_probs,
from_text,
has_logits,
):
tdin = self._make_data(
m, tokens, attention_mask, generate, from_text, has_logits
)
if from_text and generate:
assert tdin.text_response is None
elif from_text and not generate:
assert tdin.text_response is not None

td = m(tdin)
assert td is tdin
assert isinstance(td, LLMData)
if from_text and generate:
assert td.text_response is not None
if generate and (attention_mask is not None or from_text):
assert td.attention_mask is not None, (generate, from_text)
else:
assert td.text_response is None
if attention_mask is not None or from_text:
assert td.attention_mask is not None
else:
assert td.attention_mask is None
assert td.attention_mask is None, (generate, from_text)
if not generate:
assert td.text_response is None
assert td.tokens_response is None
# logprobs are computed on text_response or tokens_response
assert td.text_response is not None or td.tokens_response is not None
assert td.log_probs is not None
assert td.logits is not None
if has_logits:
assert td.logits is not None
if generate:
if return_log_probs:
assert td.log_probs is not None
assert td.log_probs.shape[-2] == td.tokens_response.shape[-1]
else:
assert td.log_probs is None

# Test the shapes
assert td.tokens_response is not None, (generate, has_logits, from_text)

# If from text and not generating, the tokens are not returned for now
if not (from_text and not generate):
assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
# The convention is that the response only has new tokens
assert (
td.tokens_response[..., : td.tokens.shape[-1]]
!= td.tokens[..., : td.tokens_response.shape[-1]]
).any(), (generate, from_text)

@pytest.mark.parametrize(
"from_text, tokens, attention_mask",
[
(True, None, None),
(
False,
torch.randint(1024, (1, 10)),
torch.ones(1, 10, dtype=torch.int64),
),
(False, torch.randint(1024, (1, 10)), None),
],
)
def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
from vllm import LLM

model = LLM(model="facebook/opt-125m")
m_generate = from_vllm(
model, from_text=from_text, generate=True, return_log_probs=True
)
m_logprobs = from_vllm(model, from_text=from_text, generate=False)
self._check_lps(
m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
)

def _check_lps(
self,
model_generate,
model_logprobs,
tokens,
attention_mask,
from_text,
has_logits,
):
# Checks that the log-probs gathered with generate=False match those gathered with generate=True
tdin_generate = self._make_data(
model_generate, tokens, attention_mask, True, from_text, has_logits
)
td_generate = model_generate(tdin_generate)
tdin_logprobs = self._make_data(
model_logprobs,
tokens,
attention_mask,
False,
from_text,
has_logits,
tokens_response=td_generate.tokens_response,
text_response=td_generate.text_response,
)
td_logprobs = model_logprobs(tdin_logprobs)
torch.testing.assert_close(
td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
)


if __name__ == "__main__":