# TODO: lazy imports

-from transformers import AutoModelForCausalLM, AutoTokenizer
-from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq, TensorDictModuleBase, WrapModule
-from tensordict import NestedKey, TensorDictBase, TensorDict
-import transformers
import torch
+import transformers
+from tensordict import NestedKey, TensorDict, TensorDictBase
+from tensordict.nn import (
+    TensorDictModule as Mod,
+    TensorDictModuleBase,
+    TensorDictSequential as Seq,
+    WrapModule,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+


def _maybe_clear_device(td):
    if td.device is None:
@@ -30,7 +36,9 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
    # TODO: how do we avoid getting these?
    del td["tokens_out", "past_key_values"]
    scores = dict(td["tokens_out", "scores"].items())
-    scores = torch.stack([scores[str(k)] for k in range(len(scores))], 1)  # shape (B, seq-len, vocab_size)
+    scores = torch.stack(
+        [scores[str(k)] for k in range(len(scores))], 1
+    )  # shape (B, seq-len, vocab_size)
    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
    td["logits"] = scores
    del td["tokens_out", "scores"]
@@ -40,33 +48,34 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
    td["log_probs"] = log_probs
    return td

+
def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
    # TODO: how do we avoid getting these?
    del td["forward", "past_key_values"]
    scores = td["forward", "logits"]
    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
    td["logits"] = scores
    del td["forward"]
-    seq_len = scores.shape[1]
+    scores.shape[1]
    tokens = td["tokens_in", "input_ids"]
    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
    td["log_probs"] = log_probs
    return td


def from_hf_transformers(
-    model: transformers.modeling_utils.PreTrainedModel,
-    *,
-    generate: bool = True,
-    return_log_probs: bool = True,
-    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
-    from_text: bool = False,
-    device: torch.device | None = None,
-    text_key: NestedKey = "text",
-    input_key: NestedKey = "input_ids",
-    kwargs: dict | None = None,
-    tokenizer_kwargs: dict | None = None,
-) -> TensorDictModuleBase:
+    model: transformers.modeling_utils.PreTrainedModel,
+    *,
+    generate: bool = True,
+    return_log_probs: bool = True,
+    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
+    from_text: bool = False,
+    device: torch.device | None = None,
+    text_key: NestedKey = "text",
+    input_key: NestedKey = "input_ids",
+    kwargs: dict | None = None,
+    tokenizer_kwargs: dict | None = None,
+) -> TensorDictModuleBase:

    # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
@@ -98,7 +107,7 @@ def from_hf_transformers(
            lambda tensor: tensor.to(device),
            in_keys=["tokens_in"],
            out_keys=["tokens_in"],
-            strict=True
+            strict=True,
        )

    if generate:
@@ -109,7 +118,10 @@ def from_hf_transformers(
                raise RuntimeError
        if not kwargs.setdefault("return_dict_in_generate", True):
            raise RuntimeError
-        if kwargs.setdefault("tokenizer", tokenizer) is not tokenizer and tokenizer is not None:
+        if (
+            kwargs.setdefault("tokenizer", tokenizer) is not tokenizer
+            and tokenizer is not None
+        ):
            raise RuntimeError

        module_dict["generate"] = Mod(
@@ -128,8 +140,8 @@ def from_hf_transformers(
            module_dict["extract_log_probs"] = WrapModule(
                log_probs_from_scores,
                in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")],
-                out_keys=["logits", "log_probs"]
-            )
+                out_keys=["logits", "log_probs"],
+            )
        if from_text:
            module_dict["decode"] = Mod(
                tokenizer.batch_decode,
@@ -159,8 +171,8 @@ def from_hf_transformers(
            module_dict["extract_log_probs"] = WrapModule(
                log_probs_from_logits,
                in_keys=[("tokens_in", "input_ids"), ("forward", "logits")],
-                out_keys=["logits", "log_probs"]
-            )
+                out_keys=["logits", "log_probs"],
+            )
    if device:
        module_dict["to_source_device"] = _maybe_set_device
    return Seq(module_dict)
@@ -171,16 +183,18 @@ def from_hf_transformers(
model_name = "Qwen/Qwen2.5-7B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype="auto",
-    device_map="auto"
+    model_name, torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.padding_side = "left"

-m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True)
+m = from_hf_transformers(
+    model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True
+)
td = m(TensorDict(text="a text"))

-m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False)
+m = from_hf_transformers(
+    model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False
+)
td = m(TensorDict(text="a text"))