
Commit bef7a20

Author: Vincent Moens
Update
[ghstack-poisoned]

1 parent 700c2ee · commit bef7a20

File tree

11 files changed: +340 additions, -81 deletions


.github/unittest/linux_libs/scripts_rlhf/environment.yml renamed to .github/unittest/linux_libs/scripts_llm/environment.yml

Lines changed: 2 additions & 1 deletion
@@ -17,5 +17,6 @@ dependencies:
   - pyyaml
   - scipy
   - hydra-core
-  - transformers<4.42.0
+  - transformers
   - datasets
+  - vllm

.github/unittest/linux_libs/scripts_rlhf/install.sh renamed to .github/unittest/linux_libs/scripts_llm/install.sh

Lines changed: 18 additions & 17 deletions
@@ -26,23 +26,24 @@ fi
 # submodules
 git submodule sync && git submodule update --init --recursive
 
-printf "Installing PyTorch with cu128"
-if [[ "$TORCH_VERSION" == "nightly" ]]; then
-    if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
-    else
-        pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
-    fi
-elif [[ "$TORCH_VERSION" == "stable" ]]; then
-    if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
-    else
-        pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
-    fi
-else
-    printf "Failed to install pytorch"
-    exit 1
-fi
+# We skip pytorch install due to vllm requirements
+#printf "Installing PyTorch with cu128"
+#if [[ "$TORCH_VERSION" == "nightly" ]]; then
+#  if [ "${CU_VERSION:-}" == cpu ] ; then
+#    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
+#  else
+#    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
+#  fi
+#elif [[ "$TORCH_VERSION" == "stable" ]]; then
+#  if [ "${CU_VERSION:-}" == cpu ] ; then
+#    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
+#  else
+#    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
+#  fi
+#else
+#  printf "Failed to install pytorch"
+#  exit 1
+#fi
 
 # install tensordict
 if [[ "$RELEASE" == 0 ]]; then

.github/unittest/linux_libs/scripts_rlhf/run_test.sh renamed to .github/unittest/linux_libs/scripts_llm/run_test.sh

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ python -c "import transformers, datasets"
 
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
 
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_actors.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow
+
 python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
   sys.device=cuda:0 sys.ref_device=cuda:0 \
   model.name_or_path=gpt2 train.max_epochs=2 \

.github/workflows/test-linux-rlhf.yml renamed to .github/workflows/test-linux-llm.yml

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-name: RLHF Tests on Linux
+name: LLM Tests on Linux
 
 on:
   pull_request:
@@ -50,7 +50,7 @@ jobs:
         export TF_CPP_MIN_LOG_LEVEL=0
         export TD_GET_DEFAULTS_TO_NONE=1
 
-        bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh
-        bash .github/unittest/linux_libs/scripts_rlhf/install.sh
-        bash .github/unittest/linux_libs/scripts_rlhf/run_test.sh
-        bash .github/unittest/linux_libs/scripts_rlhf/post_process.sh
+        bash .github/unittest/linux_libs/scripts_llm/setup_env.sh
+        bash .github/unittest/linux_libs/scripts_llm/install.sh
+        bash .github/unittest/linux_libs/scripts_llm/run_test.sh
+        bash .github/unittest/linux_libs/scripts_llm/post_process.sh

test/test_actors.py

Lines changed: 91 additions & 21 deletions
@@ -919,54 +919,108 @@ def test_lmhead_actorvalueoperator(device):
 
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
 @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
-class TestTransformerActor:
+class TestLLMActor:
     @pytest.mark.parametrize(
-        "from_text, generate, tokens, attention_mask",
+        "from_text, generate, return_log_probs, tokens, attention_mask",
         [
-            (True, True, None, None),
-            (True, False, None, None),
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
             (
                 False,
                 True,
+                True,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, True, torch.randint(1024, (1, 10)), None),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
+            (
+                False,
+                True,
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
+    def test_from_hf_transformers(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
 
+        model_name = "distilbert-base-uncased"  # or "minilm" or "albert-tiny"
+        # Load the model and tokenizer
+        # model = AutoModel.from_pretrained(model_name)
+        # tokenizer = AutoTokenizer.from_pretrained(model_name)
+
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        tokenizer.pad_token = tokenizer.eos_token
         model = GPT2LMHeadModel(GPT2Config())
+
+        tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
+
         m = from_hf_transformers(
-            model, tokenizer=tokenizer, from_text=from_text, generate=generate
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
+        self._run_check(
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=True,
         )
-        self._run_check(m, tokens, attention_mask, generate, from_text, has_logits=True)
 
     @pytest.mark.parametrize(
-        "from_text, generate, tokens, attention_mask",
+        "from_text, generate, return_log_probs, tokens, attention_mask",
         [
-            (True, True, None, None),
-            (True, False, None, None),
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
+            (
+                False,
+                True,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
             (
                 False,
                 True,
+                False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, True, torch.randint(1024, (1, 10)), None),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm(self, from_text, generate, tokens, attention_mask):
+    def test_from_vllm(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
-        m = from_vllm(model, from_text=from_text, generate=generate)
+        m = from_vllm(
+            model,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
         self._run_check(
-            m, tokens, attention_mask, generate, from_text, has_logits=False
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=False,
        )
 
     def _make_data(
@@ -1007,7 +1061,16 @@ def _make_data(
         )
         return tdin
 
-    def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits):
+    def _run_check(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        return_log_probs,
+        from_text,
+        has_logits,
+    ):
         tdin = self._make_data(
             m, tokens, attention_mask, generate, from_text, has_logits
         )
@@ -1024,13 +1087,19 @@ def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits)
         if generate and (attention_mask is not None or from_text):
             assert td.attention_mask is not None, (generate, generate, from_text)
         else:
-            assert td.attention_mask is None
+            assert td.attention_mask is None, (generate, from_text)
         if not generate:
             # logprobs are computed on text response of tokens_response
             assert td.text_response is not None or td.tokens_response is not None
             assert td.log_probs is not None
             if has_logits:
                 assert td.logits is not None
+        if generate:
+            if return_log_probs:
+                assert td.log_probs is not None
+                assert td.log_probs.shape[-2] == td.tokens_response.shape[-1]
+            else:
+                assert td.log_probs is None
 
         # Test the shapes
         assert td.tokens_response is not None, (generate, has_logits, from_text)
@@ -1042,7 +1111,7 @@ def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits)
         assert (
             td.tokens_response[..., : td.tokens.shape[-1]]
             != td.tokens[..., : td.tokens_response.shape[-1]]
-        ).any()
+        ).any(), (generate, from_text)
 
     @pytest.mark.parametrize(
         "from_text, tokens, attention_mask",
@@ -1060,7 +1129,9 @@ def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
-        m_generate = from_vllm(model, from_text=from_text, generate=True)
+        m_generate = from_vllm(
+            model, from_text=from_text, generate=True, return_log_probs=True
+        )
         m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
@@ -1091,7 +1162,6 @@ def _check_lps(
             text_response=td_generate.text_response,
         )
         td_logprobs = model_logprobs(tdin_logprobs)
-        print(td_generate.log_probs / td_logprobs.log_probs)
         torch.testing.assert_close(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )
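For context, the new return_log_probs flag is what the updated tests exercise: in generate mode the wrapper is now expected to populate log_probs aligned with tokens_response only when the flag is set. A rough usage sketch based on those assertions follows; the from_vllm import location, the "text" input key, and the exact input construction are assumptions, since _make_data is outside this hunk:

# Sketch mirroring TestLLMActor.test_from_vllm, not verbatim repo code.
from tensordict import TensorDict
from vllm import LLM

from torchrl.modules import from_vllm  # import location assumed

model = LLM(model="facebook/opt-125m")
wrapper = from_vllm(model, from_text=True, generate=True, return_log_probs=True)

# Feed a single prompt; the "text" entry name is an assumption for from_text mode.
td = wrapper(TensorDict({"text": "The quick brown fox"}, batch_size=[]))

# Mirrors the new assertions in _run_check: with return_log_probs=True,
# log_probs is present and aligned with the generated tokens.
assert td.get("log_probs") is not None
assert td.get("log_probs").shape[-2] == td.get("tokens_response").shape[-1]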

torchrl/envs/custom/llm.py

Lines changed: 5 additions & 1 deletion
@@ -143,7 +143,11 @@ def __init__(
         # self.action_key = unravel_key(action_key)
         if str2str:
             self.full_observation_spec_unbatched = Composite(
-                {self.str_key: NonTensor(example_data="a string", batched=True, shape=())}
+                {
+                    self.str_key: NonTensor(
+                        example_data="a string", batched=True, shape=()
+                    )
+                }
             )
             self.full_action_spec_unbatched = Composite(
                 {action_key: NonTensor(example_data="a string", batched=True, shape=())}
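This hunk only rewraps the str2str observation spec; the Composite it builds is unchanged. A standalone sketch of the equivalent spec, assuming the torchrl.data import path and using a hypothetical "observation" key in place of self.str_key:

# Sketch under assumptions: same spec as LLMEnv.__init__ builds for str2str=True,
# constructed outside the class with "observation" standing in for self.str_key.
from torchrl.data import Composite, NonTensor

obs_spec = Composite(
    {"observation": NonTensor(example_data="a string", batched=True, shape=())}
)
print(obs_spec)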
