@@ -84,20 +84,21 @@ def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
         response_ids = data.batch["responses"]
-        response_length = data.batch["response_mask"].sum(dim=-1)
+        response_length = torch.sum(data.batch["response_mask"], dim=-1)
         for i in range(len(data)):
-            valid_response_ids = response_ids[i][: response_length[i]]
+            cur_response_length = int(response_length[i].item())  # avoid tensor indexing error
+            valid_response_ids = response_ids[i][:cur_response_length]
             response_str = self.tokenizer.decode(
                 valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
             )
             score = self.reward_fn(
                 {
                     "response": response_str,
-                    "response_length": response_length[i],
+                    "response_length": cur_response_length,
                     "ground_truth": data.non_tensor_batch["ground_truth"][i],
                 }
             )
-            reward_tensor[i, response_length[i] - 1] = score["overall"]
+            reward_tensor[i, cur_response_length - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)

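For context, here is a minimal sketch (not part of this diff) of a per-sample reward_fn that could consume the dict built above; the function name and scoring rule are invented. It illustrates why passing "response_length" as a plain Python int is convenient: arithmetic on it stays a plain float, whereas a 0-d tensor would silently promote the returned scores, and the values appended to reward_metrics, to tensors.

```python
# Hypothetical per-sample reward function (not part of this PR); the scoring
# rule is made up. Because "response_length" arrives as a plain Python int,
# the length penalty below stays a float; a 0-d torch tensor would turn the
# whole score into a tensor instead.
from typing import Dict


def length_penalized_exact_match(reward_input: Dict) -> Dict[str, float]:
    exact_match = float(reward_input["response"].strip() == str(reward_input["ground_truth"]).strip())
    length_penalty = 0.001 * reward_input["response_length"]  # int * float -> float
    return {"overall": exact_match - length_penalty, "exact_match": exact_match}
```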
@@ -110,16 +111,17 @@ class BatchFunctionRewardManager(FunctionRewardManager):
     def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_inputs = []
         response_ids = data.batch["responses"]
-        response_length = data.batch["response_mask"].sum(dim=-1)
+        response_length = torch.sum(data.batch["response_mask"], dim=-1)
         for i in range(len(data)):
-            valid_response_ids = response_ids[i][: response_length[i]]
+            cur_response_length = int(response_length[i].item())  # avoid tensor indexing error
+            valid_response_ids = response_ids[i][:cur_response_length]
             response_str = self.tokenizer.decode(
                 valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
             )
             reward_inputs.append(
                 {
                     "response": response_str,
-                    "response_length": response_length[i],
+                    "response_length": cur_response_length,
                     "ground_truth": data.non_tensor_batch["ground_truth"][i],
                 }
             )
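Similarly, a hypothetical batch reward function that matches the reward_inputs list assembled here. This assumes the batch manager hands the whole list to self.reward_fn in the lines between this hunk and the next (not shown in the diff) and gets back one score dict per sample, which is what the scores loop below iterates over.

```python
# Hypothetical batch reward function (not shown in this diff); it receives all
# samples at once and returns one score dict per sample, in input order.
from typing import Dict, List


def batch_exact_match(reward_inputs: List[Dict]) -> List[Dict[str, float]]:
    scores = []
    for item in reward_inputs:
        exact_match = float(item["response"].strip() == str(item["ground_truth"]).strip())
        scores.append({"overall": exact_match, "exact_match": exact_match})
    return scores
```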
@@ -128,7 +130,8 @@ def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
         for i, score in enumerate(scores):
-            reward_tensor[i, response_length[i] - 1] = score["overall"]
+            cur_response_length = int(response_length[i].item())  # avoid tensor indexing error
+            reward_tensor[i, cur_response_length - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)

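Finally, a standalone sketch of the pattern this change introduces, with made-up shapes and values: valid lengths come from the response mask via torch.sum, and each length is converted to a Python int before it is used to slice the response and to index the last valid token.

```python
# Standalone illustration of the pattern above; shapes and values are made up.
import torch

responses = torch.randint(0, 100, (2, 6))           # (batch, seq_len) token ids
response_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                              [1, 1, 1, 1, 1, 0]])
response_length = torch.sum(response_mask, dim=-1)  # tensor([3, 5])
reward_tensor = torch.zeros_like(responses, dtype=torch.float32)

for i in range(responses.size(0)):
    cur_len = int(response_length[i].item())         # plain Python int, not a 0-d tensor
    valid_ids = responses[i][:cur_len]               # slice to the valid response tokens
    reward_tensor[i, cur_len - 1] = 1.0              # reward goes on the last valid token
```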