@@ -37,13 +37,13 @@ def default_spec(
37
37
return Composite (
38
38
key = Unbounded (shape = shape , dtype = torch .int64 , device = device ),
39
39
instruction_id_list = NonTensor (
40
- shape = shape + ( - 1 ,) ,
40
+ shape = shape ,
41
41
device = device ,
42
42
feature_dims = 0 ,
43
43
example_data = ["punctuation:no_comma" ],
44
44
),
45
45
kwargs = NonTensor (
46
- shape = shape + ( - 1 ,) ,
46
+ shape = shape ,
47
47
device = device ,
48
48
feature_dims = 0 ,
49
49
example_data = {
@@ -66,20 +66,14 @@ def default_spec(
66
66
def _collate_fn (batch ):
67
67
batch = torch .stack ([TensorDict .from_any (_batch ) for _batch in batch ])
68
68
batch .rename_key_ ("prompt" , "query" )
69
- if batch .get ("instruction_id_list" ).ndim == batch .ndim :
70
- # unsqueeze to ad a dimension - it must be a list
71
- torchrl_logger .info (
72
- f"Unsqueezing instruction_id_list from { batch .get ('instruction_id_list' ).shape } to { batch .get ('instruction_id_list' ).shape + (1 ,)} "
73
- )
74
- batch .set (
75
- "instruction_id_list" , lazy_stack ([batch .get ("instruction_id_list" )], - 1 )
76
- )
77
- if batch .get ("kwargs" ).ndim == batch .ndim :
78
- # unsqueeze to ad a dimension - it must be a list
79
- torchrl_logger .info (
80
- f"Unsqueezing kwargs from { batch .get ('kwargs' ).shape } to { batch .get ('kwargs' ).shape + (1 ,)} "
81
- )
82
- batch .set ("kwargs" , lazy_stack ([batch .get ("kwargs" )], - 1 ))
69
+ # we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
70
+ instruction_id_list = batch .get ("instruction_id_list" )
71
+ # instruction_id_list should be a list of lists
72
+ instruction_id_list = NonTensorStack (* [NonTensorData (item ) for item in instruction_id_list ])
73
+ kwargs = batch .get ("kwargs" )
74
+ kwargs = NonTensorStack (* [NonTensorData (item ) for item in kwargs ])
75
+ batch .set ("instruction_id_list" , instruction_id_list )
76
+ batch .set ("kwargs" , kwargs )
83
77
torchrl_logger .info (f"Collated batch: { batch } " )
84
78
# we don't need a tensorclass here
85
79
return batch
0 commit comments