@@ -919,54 +919,108 @@ def test_lmhead_actorvalueoperator(device):
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
 @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
-class TestTransformerActor:
+class TestLLMActor:
     @pytest.mark.parametrize(
-        "from_text, generate, tokens, attention_mask",
+        "from_text, generate, return_log_probs, tokens, attention_mask",
         [
-            (True, True, None, None),
-            (True, False, None, None),
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
             (
                 False,
                 True,
+                True,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, True, torch.randint(1024, (1, 10)), None),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
+            (
+                False,
+                True,
+                False,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
+    def test_from_hf_transformers(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
 
+        model_name = "distilbert-base-uncased"  # or "minilm" or "albert-tiny"
+        # Load the model and tokenizer
+        # model = AutoModel.from_pretrained(model_name)
+        # tokenizer = AutoTokenizer.from_pretrained(model_name)
+
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        tokenizer.pad_token = tokenizer.eos_token
         model = GPT2LMHeadModel(GPT2Config())
+
+        tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
+
         m = from_hf_transformers(
-            model, tokenizer=tokenizer, from_text=from_text, generate=generate
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
+        self._run_check(
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=True,
         )
-        self._run_check(m, tokens, attention_mask, generate, from_text, has_logits=True)
 
     @pytest.mark.parametrize(
-        "from_text, generate, tokens, attention_mask",
+        "from_text, generate, return_log_probs, tokens, attention_mask",
         [
-            (True, True, None, None),
-            (True, False, None, None),
+            (True, True, True, None, None),
+            (True, True, False, None, None),
+            (True, False, None, None, None),
+            (
+                False,
+                True,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, True, torch.randint(1024, (1, 10)), None),
             (
                 False,
                 True,
+                False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, True, torch.randint(1024, (1, 10)), None),
+            (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm(self, from_text, generate, tokens, attention_mask):
+    def test_from_vllm(
+        self, from_text, generate, return_log_probs, tokens, attention_mask
+    ):
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
-        m = from_vllm(model, from_text=from_text, generate=generate)
+        m = from_vllm(
+            model,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=return_log_probs,
+        )
         self._run_check(
-            m, tokens, attention_mask, generate, from_text, has_logits=False
+            m,
+            tokens,
+            attention_mask,
+            generate,
+            return_log_probs,
+            from_text,
+            has_logits=False,
         )
 
     def _make_data(
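The hunk above renames `TestTransformerActor` to `TestLLMActor` and threads a new `return_log_probs` flag through `from_hf_transformers`, `from_vllm`, and `_run_check`. Outside the harness, the flag would be used roughly as below; this is a hedged sketch, and the `torchrl.modules` import path is an assumption (the test file imports `from_vllm` from the module under test):

```python
# Usage sketch, not part of the diff: a vLLM-backed actor that also returns
# per-token log-probabilities while generating.
from vllm import LLM

from torchrl.modules import from_vllm  # assumed import path

llm = LLM(model="facebook/opt-125m")
actor = from_vllm(llm, from_text=True, generate=True, return_log_probs=True)
# Per the `_run_check` assertions below, the output tensordict then carries
# `tokens_response` plus a `log_probs` entry with one value per response token.
```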
@@ -1007,7 +1061,16 @@ def _make_data(
         )
         return tdin
 
-    def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits):
+    def _run_check(
+        self,
+        m,
+        tokens,
+        attention_mask,
+        generate,
+        return_log_probs,
+        from_text,
+        has_logits,
+    ):
         tdin = self._make_data(
             m, tokens, attention_mask, generate, from_text, has_logits
         )
@@ -1024,13 +1087,19 @@ def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits)
         if generate and (attention_mask is not None or from_text):
             assert td.attention_mask is not None, (generate, generate, from_text)
         else:
-            assert td.attention_mask is None
+            assert td.attention_mask is None, (generate, from_text)
         if not generate:
             # logprobs are computed on text response of tokens_response
             assert td.text_response is not None or td.tokens_response is not None
             assert td.log_probs is not None
             if has_logits:
                 assert td.logits is not None
+        if generate:
+            if return_log_probs:
+                assert td.log_probs is not None
+                assert td.log_probs.shape[-2] == td.tokens_response.shape[-1]
+            else:
+                assert td.log_probs is None
 
         # Test the shapes
         assert td.tokens_response is not None, (generate, has_logits, from_text)
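The added `if generate:` branch pins down the flag's contract: with `return_log_probs=True`, generation must emit a `log_probs` entry whose size along dim `-2` equals the number of response tokens; with the flag off, `log_probs` must be absent. A self-contained illustration of the asserted shape relationship (the tensors and the trailing dimension's size are made up; the test does not fix them):

```python
import torch

T = 5  # response length (illustrative)
tokens_response = torch.randint(1024, (1, T))
log_probs = torch.randn(1, T, 1)  # trailing dim size is an assumption
# Mirrors the assertion added in the hunk above.
assert log_probs.shape[-2] == tokens_response.shape[-1]
```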
@@ -1042,7 +1111,7 @@ def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits)
         assert (
             td.tokens_response[..., : td.tokens.shape[-1]]
             != td.tokens[..., : td.tokens_response.shape[-1]]
-        ).any()
+        ).any(), (generate, from_text)
 
     @pytest.mark.parametrize(
         "from_text, tokens, attention_mask",
@@ -1060,7 +1129,9 @@ def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
-        m_generate = from_vllm(model, from_text=from_text, generate=True)
+        m_generate = from_vllm(
+            model, from_text=from_text, generate=True, return_log_probs=True
+        )
         m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
@@ -1091,7 +1162,6 @@ def _check_lps(
             text_response=td_generate.text_response,
         )
         td_logprobs = model_logprobs(tdin_logprobs)
-        print(td_generate.log_probs / td_logprobs.log_probs)
         torch.testing.assert_close(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )
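Taken together, the last two hunks make `test_from_vllm_logprobs` a round-trip check: log-probabilities recorded while generating must agree, within `rtol=atol=1e-2`, with log-probabilities recomputed afterwards on the same response. Stripped of the harness, the check looks roughly like this; the `text`/`text_response` key names and the kwargs-style `TensorDict` construction are assumptions inferred from `_make_data` and `_check_lps`, not part of the diff:

```python
import torch
from tensordict import TensorDict
from vllm import LLM

from torchrl.modules import from_vllm  # assumed import path

llm = LLM(model="facebook/opt-125m")
m_generate = from_vllm(llm, from_text=True, generate=True, return_log_probs=True)
m_logprobs = from_vllm(llm, from_text=True, generate=False)

# Generate once, recording log-probs alongside the response.
td_gen = m_generate(TensorDict(text="Hello, world!", batch_size=[]))
# Re-score the same response with the log-prob-only wrapper.
td_lp = m_logprobs(
    TensorDict(
        text="Hello, world!",
        text_response=td_gen["text_response"],
        batch_size=[],
    )
)
torch.testing.assert_close(
    td_gen["log_probs"], td_lp["log_probs"], rtol=1e-2, atol=1e-2
)
```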