# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
+ import importlib.util
import os

import pytest
import torch
-
from tensordict import NonTensorStack, TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torch import distributions as dist, nn
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
+ from torchrl.data.llm import LLMData
from torchrl.data.llm.dataset import _has_transformers
- from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
+ from torchrl.modules import (
+     from_hf_transformers,
+     from_vllm,
+     MLP,
+     SafeModule,
+     TanhDelta,
+     TanhNormal,
+ )
from torchrl.modules.tensordict_module.actors import (
    _process_action_space_spec,
    ActorValueOperator,
from _utils_internal import get_default_devices
from mocking_classes import NestedCountingEnv

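+ # Gate the vLLM tests on the optional vllm dependency being installed.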
+ _has_vllm = importlib.util.find_spec("vllm") is not None
+

@pytest.mark.parametrize(
    "log_prob_key",
@@ -908,6 +918,7 @@ def test_lmhead_actorvalueoperator(device):


@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
+ @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
class TestTransformerActor:
    @pytest.mark.parametrize(
        "from_text, generate, tokens, attention_mask",
@@ -924,7 +935,6 @@ class TestTransformerActor:
        ],
    )
    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
-         from torchrl.data.llm import LLMData
        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -934,26 +944,157 @@ def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask)
        m = from_hf_transformers(
            model, tokenizer=tokenizer, from_text=from_text, generate=generate
        )
+         self._run_check(m, tokens, attention_mask, generate, from_text, has_logits=True)
+
+     @pytest.mark.parametrize(
+         "from_text, generate, tokens, attention_mask",
+         [
+             (True, True, None, None),
+             (True, False, None, None),
+             (
+                 False,
+                 True,
+                 torch.randint(1024, (1, 10)),
+                 torch.ones(1, 10, dtype=torch.int64),
+             ),
+             (False, True, torch.randint(1024, (1, 10)), None),
+         ],
+     )
+     def test_from_vllm(self, from_text, generate, tokens, attention_mask):
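+         # Wrap a vLLM model with from_vllm and run the shared input/output checks.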
+         from vllm import LLM
+
+         model = LLM(model="facebook/opt-125m")
+         m = from_vllm(model, from_text=from_text, generate=generate)
+         self._run_check(
+             m, tokens, attention_mask, generate, from_text, has_logits=False
+         )
+
+     def _make_data(
+         self,
+         m,
+         tokens,
+         attention_mask,
+         generate,
+         from_text,
+         has_logits,
+         text_response=None,
+         tokens_response=None,
+     ):
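+         # Build the LLMData input for the module: text (or tokens/attention_mask),
+         # plus a response field when generate=False so log-probs can be computed.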
+         lp_kwargs = {}
        if from_text:
-             tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
+             if not generate:
+                 text_response = (
+                     NonTensorStack(" and another text that follows")
+                     if text_response is None
+                     else text_response
+                 )
+                 if not isinstance(text_response, NonTensorStack):
+                     if isinstance(text_response, list):
+                         text_response = NonTensorStack(*text_response)
+                     else:
+                         text_response = NonTensorStack(text_response)
+                 lp_kwargs.update({"text_response": text_response})
+             tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1)
        else:
-             tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
+             if not generate:
+                 if tokens_response is None:
+                     shape_response = tokens.shape
+                     shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
+                     tokens_response = torch.randint(1024, shape_response)
+                 lp_kwargs.update({"tokens_response": tokens_response})
+             tdin = LLMData(
+                 tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1
+             )
+         return tdin
+
+     def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits):
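+         # Run the module on an input built by _make_data and check the fields of the
+         # returned LLMData (text/tokens responses, attention mask, log-probs, logits).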
+         tdin = self._make_data(
+             m, tokens, attention_mask, generate, from_text, has_logits
+         )
+         if from_text and generate:
+             assert tdin.text_response is None
+         elif from_text and not generate:
+             assert tdin.text_response is not None
+
        td = m(tdin)
        assert td is tdin
        assert isinstance(td, LLMData)
        if from_text and generate:
            assert td.text_response is not None
-         else:
-             assert td.text_response is None
-         if attention_mask is not None or from_text:
-             assert td.attention_mask is not None
+         if generate and (attention_mask is not None or from_text):
+             assert td.attention_mask is not None, (generate, from_text)
        else:
            assert td.attention_mask is None
        if not generate:
-             assert td.text_response is None
-             assert td.tokens_response is None
+             # log-probs are computed on text_response or tokens_response
+             assert td.text_response is not None or td.tokens_response is not None
            assert td.log_probs is not None
-             assert td.logits is not None
+             if has_logits:
+                 assert td.logits is not None
+
+         # Test the shapes
+         assert td.tokens_response is not None, (generate, has_logits, from_text)
+
+         # If from_text and not generating, the tokens are not returned for now
+         if not (from_text and not generate):
+             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
+             # The convention is that the response only contains new tokens
+             assert (
+                 td.tokens_response[..., : td.tokens.shape[-1]]
+                 != td.tokens[..., : td.tokens_response.shape[-1]]
+             ).any()
+
+     @pytest.mark.parametrize(
+         "from_text, tokens, attention_mask",
+         [
+             (True, None, None),
+             (
+                 False,
+                 torch.randint(1024, (1, 10)),
+                 torch.ones(1, 10, dtype=torch.int64),
+             ),
+             (False, torch.randint(1024, (1, 10)), None),
+         ],
+     )
+     def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
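+         # Wrap the same vLLM model once with generate=True and once with generate=False,
+         # then check that both paths return matching log-probs for the same response.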
+         from vllm import LLM
+
+         model = LLM(model="facebook/opt-125m")
+         m_generate = from_vllm(model, from_text=from_text, generate=True)
+         m_logprobs = from_vllm(model, from_text=from_text, generate=False)
+         self._check_lps(
+             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+         )
+
+     def _check_lps(
+         self,
+         model_generate,
+         model_logprobs,
+         tokens,
+         attention_mask,
+         from_text,
+         has_logits,
+     ):
+         # Check that the log-probs gathered with generate=False match those with generate=True
+         tdin_generate = self._make_data(
+             model_generate, tokens, attention_mask, True, from_text, has_logits
+         )
+         td_generate = model_generate(tdin_generate)
+         tdin_logprobs = self._make_data(
+             model_logprobs,
+             tokens,
+             attention_mask,
+             False,
+             from_text,
+             has_logits,
+             tokens_response=td_generate.tokens_response,
+             text_response=td_generate.text_response,
+         )
+         td_logprobs = model_logprobs(tdin_logprobs)
+         torch.testing.assert_close(
+             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+         )


if __name__ == "__main__":