Skip to content

Commit 25c6c9c

Browse files
committed
amend
1 parent 4202b5d commit 25c6c9c

File tree

2 files changed

+13
-19
lines changed

2 files changed

+13
-19
lines changed

torchrl/envs/llm/datasets/ifeval.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def default_spec(
3737
return Composite(
3838
key=Unbounded(shape=shape, dtype=torch.int64, device=device),
3939
instruction_id_list=NonTensor(
40-
shape=shape + (-1,),
40+
shape=shape,
4141
device=device,
4242
feature_dims=0,
4343
example_data=["punctuation:no_comma"],
4444
),
4545
kwargs=NonTensor(
46-
shape=shape + (-1,),
46+
shape=shape,
4747
device=device,
4848
feature_dims=0,
4949
example_data={
@@ -66,20 +66,14 @@ def default_spec(
6666
def _collate_fn(batch):
6767
batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
6868
batch.rename_key_("prompt", "query")
69-
if batch.get("instruction_id_list").ndim == batch.ndim:
70-
# unsqueeze to add a dimension - it must be a list
71-
torchrl_logger.info(
72-
f"Unsqueezing instruction_id_list from {batch.get('instruction_id_list').shape} to {batch.get('instruction_id_list').shape + (1,)}"
73-
)
74-
batch.set(
75-
"instruction_id_list", lazy_stack([batch.get("instruction_id_list")], -1)
76-
)
77-
if batch.get("kwargs").ndim == batch.ndim:
78-
# unsqueeze to add a dimension - it must be a list
79-
torchrl_logger.info(
80-
f"Unsqueezing kwargs from {batch.get('kwargs').shape} to {batch.get('kwargs').shape + (1,)}"
81-
)
82-
batch.set("kwargs", lazy_stack([batch.get("kwargs")], -1))
69+
# we want each per-sample entry of instruction_id_list and kwargs wrapped whole in NonTensorData; only the outer batch level is a NonTensorStack
70+
instruction_id_list = batch.get("instruction_id_list")
71+
# instruction_id_list should be a list of lists
72+
instruction_id_list = NonTensorStack(*[NonTensorData(item) for item in instruction_id_list])
73+
kwargs = batch.get("kwargs")
74+
kwargs = NonTensorStack(*[NonTensorData(item) for item in kwargs])
75+
batch.set("instruction_id_list", instruction_id_list)
76+
batch.set("kwargs", kwargs)
8377
torchrl_logger.info(f"Collated batch: {batch}")
8478
# we don't need a tensorclass here
8579
return batch

torchrl/envs/transforms/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6531,9 +6531,9 @@ def _reset_func(
65316531
# *reset_val.keys(True)
65326532
# )
65336533
# )
6534-
tensordict_reset = reset_val.new_zeros(_reset.shape, empty_lazy=True)
6535-
print(f"tensordict_reset: {tensordict_reset}")
6536-
print(f"reset_val: {reset_val}")
6534+
tensordict_reset = reset_val.new_zeros(
6535+
_reset.shape, empty_lazy=True
6536+
)
65376537
tensordict_reset[_reset] = reset_val
65386538
else:
65396539
resets = self.default_value(reset=_reset)

0 commit comments

Comments
 (0)