
Commit 0b3f7a4

[Feature] transformers policy

ghstack-source-id: 6f509ae
Pull Request resolved: #2825

1 parent: 4caa157

File tree

4 files changed: +219 −19 lines


torchrl/data/postprocs/postprocs.py (+1 −1)

@@ -12,7 +12,6 @@
 from torch import nn
 
 
-
 def _get_reward(
     gamma: float,
     reward: torch.Tensor,
@@ -367,6 +366,7 @@ def __init__(
         discount: float = 1.0,
     ):
         from torchrl.objectives.value.functional import reward2go
+
        super().__init__()
        self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
        if reward_key_out is None:
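
Note: the hunk above only touches the deferred import of reward2go. For readers unfamiliar with that primitive, here is a minimal standalone sketch of a discounted reward-to-go computation; reward2go_sketch is a hypothetical name, and the real torchrl.objectives.value.functional.reward2go has a richer signature and batching behavior.

import torch

def reward2go_sketch(reward: torch.Tensor, done: torch.Tensor, gamma: float) -> torch.Tensor:
    # Backward pass over the time dimension: accumulate discounted returns,
    # resetting the accumulator at trajectory boundaries marked by `done`.
    out = torch.zeros_like(reward)
    running = torch.zeros_like(reward[..., -1])
    for t in range(reward.shape[-1] - 1, -1, -1):
        running = reward[..., t] + gamma * running * (~done[..., t]).float()
        out[..., t] = running
    return out

reward = torch.tensor([1.0, 1.0, 1.0, 1.0])
done = torch.tensor([False, True, False, False])
print(reward2go_sketch(reward, done, 0.9))  # tensor([1.9000, 1.0000, 1.9000, 1.0000])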

torchrl/envs/custom/llm.py (+17 −17)

@@ -80,7 +80,9 @@ def __init__(
             self._batch_locked = False
         else:
             self._batch_locked = True
-        super().__init__(device=device, batch_size=() if batch_size is None else (batch_size,))
+        super().__init__(
+            device=device, batch_size=() if batch_size is None else (batch_size,)
+        )
         self.str2str = str2str
         self.vocab_size = vocab_size
         self.observation_key = unravel_key(token_key)
@@ -92,29 +94,21 @@ def __init__(
         # self.action_key = unravel_key(action_key)
         if str2str:
             self.full_observation_spec_unbatched = Composite(
-                {
-                    token_key: NonTensor(
-                        example_data="a string", batched=True, shape=()
-                    )
-                }
+                {token_key: NonTensor(example_data="a string", batched=True, shape=())}
             )
             self.full_action_spec_unbatched = Composite(
                 {action_key: NonTensor(example_data="a string", batched=True, shape=())}
             )
         else:
             if vocab_size is None:
                 observation_spec = {
-                    token_key: Unbounded(
-                        shape=(-1,), dtype=torch.int64, device=device
-                    )
-                }
+                    token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device)
+                }
                 if attention_key is not None:
                     observation_spec[attention_key] = Unbounded(
-                        shape=(-1,), dtype=torch.int64, device=device
-                    )
-                self.full_observation_spec_unbatched = Composite(
-                    observation_spec
-                )
+                        shape=(-1,), dtype=torch.int64, device=device
+                    )
+                self.full_observation_spec_unbatched = Composite(observation_spec)
             self.full_action_spec_unbatched = Composite(
                 {
                     action_key: Unbounded(
@@ -325,7 +319,13 @@ def _make_next_obs(
         if self.attention_key is not None:
             attention_mask = tensordict.get(self.attention_key)
             n = action.shape[-1] - attention_mask.shape[-1]
-            attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1)
+            attention_mask = torch.cat(
+                [
+                    attention_mask,
+                    attention_mask.new_ones(attention_mask.shape[:-1] + (n,)),
+                ],
+                -1,
+            )
             nex_td.set(self.attention_key, attention_mask)
         return nex_td
@@ -384,7 +384,7 @@ def _make_next_obs(
 
     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         # We should have an observation by this time, if not raise an exception
-        print('tensordict', tensordict)
+        print("tensordict", tensordict)
         if tensordict is None or self.observation_key not in tensordict.keys(
             isinstance(self.observation_key, tuple)
         ):
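
Note: the _make_next_obs hunk above only reformats an existing torch.cat call, but the underlying pattern is worth spelling out: after generation, the action holds the full sequence (prompt plus new tokens), so the attention mask is padded with ones for the newly generated positions. A toy illustration (values are assumptions, not part of the commit):

import torch

attention_mask = torch.tensor([[0, 1, 1]])       # left-padded prompt mask, shape (1, 3)
action = torch.zeros(1, 5, dtype=torch.int64)    # prompt + 2 generated tokens
n = action.shape[-1] - attention_mask.shape[-1]  # 2 new positions to unmask
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1
)
print(attention_mask)  # tensor([[0, 1, 1, 1, 1]])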

torchrl/envs/transforms/rlhf.py (+1 −1)

@@ -461,7 +461,7 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             raise ValueError(
                 f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
             )
-        print('out', out)
+        print("out", out)
         if self.use_buffer:
             if not out.ndim:
                 out = out.unsqueeze(0)

torchrl/modules/llm/transformers.py (new file, +200)

@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# TODO: lazy imports
+
+import torch
+import transformers
+from tensordict import NestedKey, NonTensorData, TensorDict, TensorDictBase
+from tensordict.nn import (
+    TensorDictModule as Mod,
+    TensorDictModuleBase,
+    TensorDictSequential as Seq,
+    WrapModule,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def _maybe_clear_device(td):
+    if td.device is None:
+        return td
+    # Remember the source device as a non-tensor entry so it can be restored later.
+    return td.set("_source_device", NonTensorData(td.device)).clear_device_()
+
+
+def _maybe_set_device(td):
+    device = td.pop("_source_device", None)
+    if device is None:
+        return td
+    elif isinstance(device, NonTensorData):
+        device: torch.device = device.data
+    return td.to(device)
+
+
+def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
+    # TODO: how do we avoid getting these?
+    del td["tokens_out", "past_key_values"]
+    scores = dict(td["tokens_out", "scores"].items())
+    scores = torch.stack(
+        [scores[str(k)] for k in range(len(scores))], 1
+    )  # shape (B, seq-len, vocab_size)
+    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
+    td["logits"] = scores
+    del td["tokens_out", "scores"]
+    seq_len = scores.shape[1]
+    tokens = td["tokens_out", "sequences"][..., -seq_len:]  # shape (B, seq-len)
+    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    td["log_probs"] = log_probs
+    return td
+
+
+def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
+    # TODO: how do we avoid getting these?
+    del td["forward", "past_key_values"]
+    scores = td["forward", "logits"]
+    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
+    td["logits"] = scores
+    del td["forward"]
+    tokens = td["tokens_in", "input_ids"]
+    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    td["log_probs"] = log_probs
+    return td
+
+
+def from_hf_transformers(
+    model: transformers.modeling_utils.PreTrainedModel,
+    *,
+    generate: bool = True,
+    return_log_probs: bool = True,
+    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
+    from_text: bool = False,
+    device: torch.device | None = None,
+    text_key: NestedKey = "text",
+    input_key: NestedKey = "input_ids",
+    kwargs: dict | None = None,
+    tokenizer_kwargs: dict | None = None,
+) -> TensorDictModuleBase:
+
+    # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
+
+    module_dict = {}
+    if device:
+        module_dict["clear_device"] = _maybe_clear_device
+    if from_text:
+        if not tokenizer_kwargs:
+            tokenizer_kwargs = {}
+        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
+            raise RuntimeError
+        if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
+            raise RuntimeError
+        # TODO: add other paddings
+        if tokenizer_kwargs.setdefault("padding", True) not in (True,):
+            raise RuntimeError
+        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
+            raise RuntimeError
+
+        module_dict["encode"] = Mod(
+            tokenizer,
+            in_keys=[text_key],
+            out_keys=["tokens_in"],
+            method_kwargs=tokenizer_kwargs,
+            strict=True,
+        )
+        if device:
+            module_dict["to_dest_device"] = Mod(
+                lambda tensor: tensor.to(device),
+                in_keys=["tokens_in"],
+                out_keys=["tokens_in"],
+                strict=True,
+            )
+
+    if generate:
+        if not kwargs:
+            kwargs = {}
+        if return_log_probs:
+            if not kwargs.setdefault("output_scores", True):
+                raise RuntimeError
+            if not kwargs.setdefault("return_dict_in_generate", True):
+                raise RuntimeError
+        if (
+            kwargs.setdefault("tokenizer", tokenizer) is not tokenizer
+            and tokenizer is not None
+        ):
+            raise RuntimeError
+
+        module_dict["generate"] = Mod(
+            model,
+            method="generate",
+            method_kwargs=kwargs,
+            in_keys={
+                "input_ids": ("tokens_in", "input_ids"),
+                "attention_mask": ("tokens_in", "attention_mask"),
+            },
+            out_keys=["tokens_out"],
+            out_to_in_map=True,
+            strict=True,
+        )
+        if return_log_probs:
+            module_dict["extract_log_probs"] = WrapModule(
+                log_probs_from_scores,
+                in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")],
+                out_keys=["logits", "log_probs"],
+            )
+        if from_text:
+            module_dict["decode"] = Mod(
+                tokenizer.batch_decode,
+                in_keys=[("tokens_out", "sequences")],
+                out_keys=["action"],
+                strict=True,
+            )
+
+    else:
+        if not kwargs:
+            kwargs = {}
+        if not kwargs.setdefault("return_dict", True):
+            raise RuntimeError
+        if not return_log_probs:
+            raise RuntimeError
+        module_dict["get_logprobs"] = Mod(
+            model,
+            method_kwargs=kwargs,
+            in_keys={
+                "input_ids": ("tokens_in", "input_ids"),
+                "attention_mask": ("tokens_in", "attention_mask"),
+            },
+            out_keys=["forward"],
+            out_to_in_map=True,
+            strict=True,
+        )
+        module_dict["extract_log_probs"] = WrapModule(
+            log_probs_from_logits,
+            in_keys=[("tokens_in", "input_ids"), ("forward", "logits")],
+            out_keys=["logits", "log_probs"],
+        )
+    if device:
+        module_dict["to_source_device"] = _maybe_set_device
+    return Seq(module_dict)
+
+
+if __name__ == "__main__":
+    max_seq_length = 50000
+    model_name = "Qwen/Qwen2.5-7B-Instruct"
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype="auto", device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    tokenizer.padding_side = "left"
+
+    m = from_hf_transformers(
+        model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True
+    )
+    td = m(TensorDict(text="a text"))
+
+    m = from_hf_transformers(
+        model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False
+    )
+    td = m(TensorDict(text="a text"))
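
Note: for readers new to the pattern used by log_probs_from_scores and log_probs_from_logits above, here is a minimal, self-contained sketch (toy tensor shapes, assumed values, not the module's actual inputs) of how per-token log-probabilities are recovered by normalizing scores and gathering at the realized token ids:

import torch

B, T, V = 2, 4, 10                                   # toy batch, seq-len, vocab sizes
scores = torch.randn(B, T, V)                        # stacked per-step generation scores
logits = scores - scores.logsumexp(dim=-1, keepdim=True)  # manual log-softmax
tokens = torch.randint(V, (B, T))                    # token ids actually generated
log_probs = logits.gather(-1, tokens.unsqueeze(-1))  # (B, T, 1): log-prob of each token
expected = torch.log_softmax(scores, dim=-1).gather(-1, tokens.unsqueeze(-1))
assert torch.allclose(log_probs, expected)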
