Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
)
def __init__({RLConfig_arguments},
use_vision = False,
vllm_sampling_params = None,
unsloth_num_chunks = -1,
**kwargs,
Expand All @@ -142,6 +143,7 @@ def __init__({RLConfig_arguments},
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
self.vllm_sampling_params = vllm_sampling_params
self.unsloth_num_chunks = unsloth_num_chunks
self.use_vision = use_vision
pass

{RLTrainer_extras}
Expand Down Expand Up @@ -233,6 +235,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):

# Edit bf16, fp16 by checking model's torch_dtype directly
extra_args = ""

# Add boolean for vision support
if "args" in call_args :
use_vision = "self.use_vision = args.use_vision\n"
extra_args += use_vision

if "args" in call_args and "model" in call_args:
mixed_precision = \
"use_bf16 = getattr(args, 'bf16', False)\n"\
Expand Down
Loading