Commit d370867

Author: Vincent Moens

Commit message:

    Update
    [ghstack-poisoned]

1 parent: d5cdc48

File tree: 4 files changed (+62 / -48 lines)

torchrl/data/postprocs/postprocs.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,6 @@
 from torch import nn
 
 
-
 def _get_reward(
     gamma: float,
     reward: torch.Tensor,
@@ -367,6 +366,7 @@ def __init__(
         discount: float = 1.0,
     ):
         from torchrl.objectives.value.functional import reward2go
+
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:
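
For context, `reward2go` (imported inside `__init__` above) computes a discounted reward-to-go along the time dimension. A minimal standalone sketch of that recursion, assuming a trailing time dimension and a boolean done mask; this is illustrative only, not the library's exact signature:

    import torch

    def reward2go_sketch(reward: torch.Tensor, done: torch.Tensor, gamma: float) -> torch.Tensor:
        # r2g[t] = reward[t] + gamma * r2g[t + 1], restarting whenever done[t] is True.
        r2g = torch.zeros_like(reward)
        running = torch.zeros_like(reward[..., 0])
        for t in range(reward.shape[-1] - 1, -1, -1):
            running = reward[..., t] + gamma * running * (~done[..., t]).float()
            r2g[..., t] = running
        return r2g

    # reward2go_sketch(torch.tensor([1.0, 1.0, 1.0]), torch.tensor([False, False, True]), 0.9)
    # -> tensor([2.7100, 1.9000, 1.0000])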

torchrl/envs/custom/llm.py

Lines changed: 17 additions & 17 deletions

@@ -80,7 +80,9 @@ def __init__(
             self._batch_locked = False
         else:
             self._batch_locked = True
-        super().__init__(device=device, batch_size=() if batch_size is None else (batch_size,))
+        super().__init__(
+            device=device, batch_size=() if batch_size is None else (batch_size,)
+        )
         self.str2str = str2str
         self.vocab_size = vocab_size
         self.observation_key = unravel_key(token_key)
@@ -92,29 +94,21 @@ def __init__(
         # self.action_key = unravel_key(action_key)
         if str2str:
             self.full_observation_spec_unbatched = Composite(
-                {
-                    token_key: NonTensor(
-                        example_data="a string", batched=True, shape=()
-                    )
-                }
+                {token_key: NonTensor(example_data="a string", batched=True, shape=())}
             )
             self.full_action_spec_unbatched = Composite(
                 {action_key: NonTensor(example_data="a string", batched=True, shape=())}
             )
         else:
             if vocab_size is None:
                 observation_spec = {
-                    token_key: Unbounded(
-                        shape=(-1,), dtype=torch.int64, device=device
-                    )
-                }
+                    token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device)
+                }
                 if attention_key is not None:
                     observation_spec[attention_key] = Unbounded(
-                    shape=(-1,), dtype=torch.int64, device=device
-                )
-                self.full_observation_spec_unbatched = Composite(
-                    observation_spec
-                )
+                        shape=(-1,), dtype=torch.int64, device=device
+                    )
+                self.full_observation_spec_unbatched = Composite(observation_spec)
             self.full_action_spec_unbatched = Composite(
                 {
                     action_key: Unbounded(
@@ -325,7 +319,13 @@ def _make_next_obs(
         if self.attention_key is not None:
             attention_mask = tensordict.get(self.attention_key)
             n = action.shape[-1] - attention_mask.shape[-1]
-            attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1)
+            attention_mask = torch.cat(
+                [
+                    attention_mask,
+                    attention_mask.new_ones(attention_mask.shape[:-1] + (n,)),
+                ],
+                -1,
+            )
             nex_td.set(self.attention_key, attention_mask)
         return nex_td
 
@@ -384,7 +384,7 @@ def _make_next_obs(
 
     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         # We should have an observation by this time, if not raise an exception
-        print('tensordict', tensordict)
+        print("tensordict", tensordict)
         if tensordict is None or self.observation_key not in tensordict.keys(
             isinstance(self.observation_key, tuple)
         ):
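
The reformatted `torch.cat` call above extends the attention mask with ones so that it covers the newly generated tokens. A toy, self-contained version of that pattern (tensor names and shapes are illustrative):

    import torch

    attention_mask = torch.ones(1, 5, dtype=torch.int64)  # mask over a 5-token prompt
    action = torch.arange(8).unsqueeze(0)  # stands in for the 8-token full sequence

    # Pad the mask with ones for the n newly generated tokens.
    n = action.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1
    )
    assert attention_mask.shape == (1, 8)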

torchrl/envs/transforms/rlhf.py

Lines changed: 1 addition & 1 deletion

@@ -461,7 +461,7 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             raise ValueError(
                 f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
             )
-        print('out', out)
+        print("out", out)
         if self.use_buffer:
             if not out.ndim:
                 out = out.unsqueeze(0)

torchrl/modules/llm/transformers.py

Lines changed: 43 additions & 29 deletions

@@ -5,11 +5,17 @@
 
 # TODO: lazy imports
 
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq, TensorDictModuleBase, WrapModule
-from tensordict import NestedKey, TensorDictBase, TensorDict
-import transformers
 import torch
+import transformers
+from tensordict import NestedKey, TensorDict, TensorDictBase
+from tensordict.nn import (
+    TensorDictModule as Mod,
+    TensorDictModuleBase,
+    TensorDictSequential as Seq,
+    WrapModule,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 
 def _maybe_clear_device(td):
     if td.device is None:
@@ -30,7 +36,9 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
     # TODO: how do we avoid getting these?
     del td["tokens_out", "past_key_values"]
     scores = dict(td["tokens_out", "scores"].items())
-    scores = torch.stack([scores[str(k)] for k in range(len(scores))], 1)  # shape (B, seq-len, vocab_size)
+    scores = torch.stack(
+        [scores[str(k)] for k in range(len(scores))], 1
+    )  # shape (B, seq-len, vocab_size)
     logits = scores - scores.logsumexp(dim=-1, keepdim=True)
     td["logits"] = scores
     del td["tokens_out", "scores"]
@@ -40,33 +48,34 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
     td["log_probs"] = log_probs
     return td
 
+
 def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
     # TODO: how do we avoid getting these?
     del td["forward", "past_key_values"]
     scores = td["forward", "logits"]
     logits = scores - scores.logsumexp(dim=-1, keepdim=True)
     td["logits"] = scores
     del td["forward"]
-    seq_len = scores.shape[1]
+    scores.shape[1]
     tokens = td["tokens_in", "input_ids"]
     log_probs = logits.gather(-1, tokens.unsqueeze(-1))
     td["log_probs"] = log_probs
     return td
 
 
 def from_hf_transformers(
-        model: transformers.modeling_utils.PreTrainedModel,
-        *,
-        generate: bool = True,
-        return_log_probs: bool = True,
-        tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
-        from_text: bool = False,
-        device: torch.device | None = None,
-        text_key: NestedKey = "text",
-        input_key: NestedKey = "input_ids",
-        kwargs: dict | None = None,
-        tokenizer_kwargs: dict | None = None,
-) -> TensorDictModuleBase:
+    model: transformers.modeling_utils.PreTrainedModel,
+    *,
+    generate: bool = True,
+    return_log_probs: bool = True,
+    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
+    from_text: bool = False,
+    device: torch.device | None = None,
+    text_key: NestedKey = "text",
+    input_key: NestedKey = "input_ids",
+    kwargs: dict | None = None,
+    tokenizer_kwargs: dict | None = None,
+) -> TensorDictModuleBase:
 
     # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
 
@@ -98,7 +107,7 @@ def from_hf_transformers(
         lambda tensor: tensor.to(device),
         in_keys=["tokens_in"],
         out_keys=["tokens_in"],
-        strict=True
+        strict=True,
     )
 
     if generate:
@@ -109,7 +118,10 @@ def from_hf_transformers(
             raise RuntimeError
         if not kwargs.setdefault("return_dict_in_generate", True):
             raise RuntimeError
-        if kwargs.setdefault("tokenizer", tokenizer) is not tokenizer and tokenizer is not None:
+        if (
+            kwargs.setdefault("tokenizer", tokenizer) is not tokenizer
+            and tokenizer is not None
+        ):
             raise RuntimeError
 
         module_dict["generate"] = Mod(
@@ -128,8 +140,8 @@ def from_hf_transformers(
         module_dict["extract_log_probs"] = WrapModule(
             log_probs_from_scores,
             in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")],
-            out_keys=["logits", "log_probs"]
-        )
+            out_keys=["logits", "log_probs"],
+        )
     if from_text:
         module_dict["decode"] = Mod(
             tokenizer.batch_decode,
@@ -159,8 +171,8 @@ def from_hf_transformers(
         module_dict["extract_log_probs"] = WrapModule(
             log_probs_from_logits,
             in_keys=[("tokens_in", "input_ids"), ("forward", "logits")],
-            out_keys=["logits", "log_probs"]
-        )
+            out_keys=["logits", "log_probs"],
+        )
     if device:
         module_dict["to_source_device"] = _maybe_set_device
     return Seq(module_dict)
@@ -171,16 +183,18 @@ def from_hf_transformers(
     model_name = "Qwen/Qwen2.5-7B-Instruct"
 
     model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype="auto",
-        device_map="auto"
+        model_name, torch_dtype="auto", device_map="auto"
    )
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     tokenizer.padding_side = "left"
 
-    m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True)
+    m = from_hf_transformers(
+        model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True
+    )
     td = m(TensorDict(text="a text"))
 
-    m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False)
+    m = from_hf_transformers(
+        model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False
+    )
     td = m(TensorDict(text="a text"))
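
A quick sanity check of the log-prob extraction pattern used in `log_probs_from_scores` and `log_probs_from_logits`: subtracting `logsumexp` from the logits is a manual `log_softmax`, and `gather` then picks out each token's log-probability. A toy verification with assumed shapes:

    import torch

    scores = torch.randn(2, 5, 100)  # (batch, seq_len, vocab_size)
    tokens = torch.randint(100, (2, 5))  # token ids

    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
    log_probs = logits.gather(-1, tokens.unsqueeze(-1))

    reference = torch.log_softmax(scores, dim=-1).gather(-1, tokens.unsqueeze(-1))
    assert torch.allclose(log_probs, reference)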
