@@ -47,7 +47,6 @@ def status(self) -> TrainingStatus:
47
47
48
48
class ArborReinforceJob (ReinforceJob ):
49
49
DEFAULT_TRAIN_KWARGS = { # noqa: RUF012
50
- "update_interval" : 10 ,
51
50
"temperature" : 0.9 ,
52
51
"beta" : 0.04 ,
53
52
"num_iterations" : 1 ,
@@ -85,7 +84,6 @@ def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs):
85
84
def initialize (self ):
86
85
# TODO(GRPO Team): Set provider job ID
87
86
num_generations = self .train_kwargs .get ("num_generations" )
88
- update_interval = self .train_kwargs .get ("update_interval" , self .DEFAULT_TRAIN_KWARGS ["update_interval" ])
89
87
temperature = self .train_kwargs .get ("temperature" , self .DEFAULT_TRAIN_KWARGS ["temperature" ])
90
88
beta = self .train_kwargs .get ("beta" , self .DEFAULT_TRAIN_KWARGS ["beta" ])
91
89
num_iterations = self .train_kwargs .get ("num_iterations" , self .DEFAULT_TRAIN_KWARGS ["num_iterations" ])
@@ -125,7 +123,6 @@ def initialize(self):
125
123
"model" : finetune_model ,
126
124
"suffix" : suffix ,
127
125
"num_generations" : num_generations ,
128
- "update_interval" : update_interval ,
129
126
"temperature" : temperature ,
130
127
"beta" : beta ,
131
128
"num_iterations" : num_iterations ,
@@ -161,7 +158,7 @@ def _run_grpo_step_one_group(
161
158
# api_key = self.lm.kwargs["api_key"]
162
159
163
160
finetune_model = ArborProvider ._remove_provider_prefix (self .lm .model )
164
- data = {"model" : finetune_model , "update_inference_model" : True , " batch" : train_group }
161
+ data = {"model" : finetune_model , "batch" : train_group }
165
162
url = f"{ api_base } fine_tuning/grpo/step"
166
163
headers = {"Content-Type" : "application/json" }
167
164
response = requests .post (url , headers = headers , json = data )
@@ -186,18 +183,6 @@ def step(self, train_data: List[GRPOGroup], train_data_format: Optional[Union[Tr
186
183
for group in train_data :
187
184
self ._run_grpo_step_one_group (group , train_data_format )
188
185
189
- def update_model (self ):
190
- api_base = self .lm .kwargs ["api_base" ]
191
-
192
- url = f"{ api_base } fine_tuning/grpo/update_model"
193
- headers = {"Content-Type" : "application/json" }
194
- response = requests .post (url , headers = headers )
195
- assert response .status_code == 200 , f"Failed to update model: { response .text } "
196
-
197
- response = response .json ()
198
- current_model = response ["current_model" ]
199
- self .lm .model = ArborProvider ._add_provider_prefix (current_model )
200
-
201
186
def save_checkpoint (self , checkpoint_name : str , score : Optional [float ] = None ):
202
187
api_base = self .lm .kwargs ["api_base" ]
203
188
url = f"{ api_base } fine_tuning/grpo/checkpoint"
@@ -254,9 +239,6 @@ def __init__(self):
254
239
@staticmethod
255
240
def launch (lm : "LM" , launch_kwargs : Optional [Dict [str , Any ]] = None ):
256
241
model = ArborProvider ._remove_provider_prefix (lm .model )
257
- # TODO: Handle this on the server side
258
- if model .startswith ("huggingface/" ):
259
- model = model [len ("huggingface/" ) :]
260
242
261
243
api_base = lm .kwargs ["api_base" ]
262
244
0 commit comments