Skip to content

Commit dae38d3

Browse files
authored
Remove unused GRPO endpoint (#8354)
1 parent de11fbb commit dae38d3

File tree

4 files changed, with 1 addition and 26 deletions.

docs/docs/tutorials/rl_multihop/index.ipynb

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,6 @@
246246
"\n",
247247
"# NOTE: Training on 6 GPUs.\n",
248248
"train_kwargs = {\n",
249-
" \"update_interval\": 3,\n",
250249
" \"per_device_train_batch_size\": 2,\n",
251250
" \"gradient_accumulation_steps\": 4,\n",
252251
" \"temperature\": 0.7,\n",

docs/docs/tutorials/rl_papillon/index.ipynb

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -275,7 +275,6 @@
275275
"\n",
276276
"# NOTE: Training on 3 GPUs.\n",
277277
"train_kwargs = {\n",
278-
" \"update_interval\": 3,\n",
279278
" \"per_device_train_batch_size\": 8,\n",
280279
" \"gradient_accumulation_steps\": 4,\n",
281280
" \"temperature\": 0.7,\n",

dspy/clients/lm_local_arbor.py

Lines changed: 1 addition & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@ def status(self) -> TrainingStatus:
4747

4848
class ArborReinforceJob(ReinforceJob):
4949
DEFAULT_TRAIN_KWARGS = { # noqa: RUF012
50-
"update_interval": 10,
5150
"temperature": 0.9,
5251
"beta": 0.04,
5352
"num_iterations": 1,
@@ -85,7 +84,6 @@ def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs):
8584
def initialize(self):
8685
# TODO(GRPO Team): Set provider job ID
8786
num_generations = self.train_kwargs.get("num_generations")
88-
update_interval = self.train_kwargs.get("update_interval", self.DEFAULT_TRAIN_KWARGS["update_interval"])
8987
temperature = self.train_kwargs.get("temperature", self.DEFAULT_TRAIN_KWARGS["temperature"])
9088
beta = self.train_kwargs.get("beta", self.DEFAULT_TRAIN_KWARGS["beta"])
9189
num_iterations = self.train_kwargs.get("num_iterations", self.DEFAULT_TRAIN_KWARGS["num_iterations"])
@@ -125,7 +123,6 @@ def initialize(self):
125123
"model": finetune_model,
126124
"suffix": suffix,
127125
"num_generations": num_generations,
128-
"update_interval": update_interval,
129126
"temperature": temperature,
130127
"beta": beta,
131128
"num_iterations": num_iterations,
@@ -161,7 +158,7 @@ def _run_grpo_step_one_group(
161158
# api_key = self.lm.kwargs["api_key"]
162159

163160
finetune_model = ArborProvider._remove_provider_prefix(self.lm.model)
164-
data = {"model": finetune_model, "update_inference_model": True, "batch": train_group}
161+
data = {"model": finetune_model, "batch": train_group}
165162
url = f"{api_base}fine_tuning/grpo/step"
166163
headers = {"Content-Type": "application/json"}
167164
response = requests.post(url, headers=headers, json=data)
@@ -186,18 +183,6 @@ def step(self, train_data: List[GRPOGroup], train_data_format: Optional[Union[Tr
186183
for group in train_data:
187184
self._run_grpo_step_one_group(group, train_data_format)
188185

189-
def update_model(self):
190-
api_base = self.lm.kwargs["api_base"]
191-
192-
url = f"{api_base}fine_tuning/grpo/update_model"
193-
headers = {"Content-Type": "application/json"}
194-
response = requests.post(url, headers=headers)
195-
assert response.status_code == 200, f"Failed to update model: {response.text}"
196-
197-
response = response.json()
198-
current_model = response["current_model"]
199-
self.lm.model = ArborProvider._add_provider_prefix(current_model)
200-
201186
def save_checkpoint(self, checkpoint_name: str, score: Optional[float] = None):
202187
api_base = self.lm.kwargs["api_base"]
203188
url = f"{api_base}fine_tuning/grpo/checkpoint"
@@ -254,9 +239,6 @@ def __init__(self):
254239
@staticmethod
255240
def launch(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
256241
model = ArborProvider._remove_provider_prefix(lm.model)
257-
# TODO: Handle this on the server side
258-
if model.startswith("huggingface/"):
259-
model = model[len("huggingface/") :]
260242

261243
api_base = lm.kwargs["api_base"]
262244

dspy/teleprompt/grpo.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -538,11 +538,6 @@ def compile(
538538

539539
job.step(train_data=train_data, train_data_format=TrainDataFormat.GRPO_CHAT)
540540

541-
for (lm, _), job in grpo_training_jobs.items():
542-
if (train_step_idx + 1) % self.train_kwargs[lm]["update_interval"] == 0 and train_step_idx != 0:
543-
logger.info(f"Current train step is {train_step_idx + 1}. Updating the model...")
544-
job.update_model()
545-
546541
logger.info(f"GRPO training step {train_step_idx + 1}/{self.num_train_steps} completed.")
547542

548543
self.report_validation_metrics(

0 commit comments

Comments
 (0)